diff options
author | Tai Ly <tai.ly@arm.com> | 2023-03-28 22:06:56 +0000 |
---|---|---|
committer | Tai Ly <tai.ly@arm.com> | 2023-05-05 19:23:15 +0000 |
commit | a4d748b08accce06fab93e2d2b96e499b35ae89b (patch) | |
tree | 20a3957e1f45f65f35d5d67ecce1618659e388f0 /reference_model/src/ops | |
parent | 0c71686875618b2e11290273b7a05b88ef8a8aae (diff) | |
download | reference_model-a4d748b08accce06fab93e2d2b96e499b35ae89b.tar.gz |
[reference model] Add precise mode
This adds a --precise_mode=1 option to tosa_reference_model,
which causes the reference model to convert all floating-point tensors
to FP64 tensors and compute all operators accordingly.
Also adds optional -p arguments to the test runners tosa_verif_run_tests.py
and tosa_verif_framework_compiler_runner.py to run tests in precise mode.
Signed-off-by: Tai Ly <tai.ly@arm.com>
Change-Id: I156055216ad61710096497a8fa1a653be2a602a3
Diffstat (limited to 'reference_model/src/ops')
28 files changed, 998 insertions, 649 deletions
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc index 24bd077..6681d6d 100644 --- a/reference_model/src/ops/activation_funcs.cc +++ b/reference_model/src/ops/activation_funcs.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpClamp<Rank, Dtype>::register_fcn() { // Check Tosa Level @@ -32,9 +32,9 @@ int OpClamp<Rank, Dtype>::register_fcn() switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: { InEigenType min = (InEigenType)attribute->min_fp(); InEigenType max = (InEigenType)attribute->max_fp(); @@ -43,8 +43,17 @@ int OpClamp<Rank, Dtype>::register_fcn() this->fcn = [min, max](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(a <= min ? min : a >= max ? max : a); }; } break; - case DType_INT8: - case DType_INT16: + case TOSA_REF_TYPE_FP64: + { + InEigenType min = (InEigenType)attribute->min_fp(); + InEigenType max = (InEigenType)attribute->max_fp(); + ERROR_IF(max < min, "OpClamp: max smaller than min"); + + this->fcn = [min, max](InEigenType a) -> OutEigenType { return (a <= min ? min : a >= max ? 
max : a); }; + } + break; + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: { InEigenType min = (InEigenType)attribute->min_int(); InEigenType max = (InEigenType)attribute->max_int(); @@ -53,19 +62,19 @@ int OpClamp<Rank, Dtype>::register_fcn() } break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpClamp<Rank, Dtype>::~OpClamp() { if (attribute) delete attribute; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpSigmoid<Rank, Dtype>::register_fcn() { // Check Tosa Level @@ -74,21 +83,24 @@ int OpSigmoid<Rank, Dtype>::register_fcn() switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(1.f / (1.f + (expf(-1.f * a)))); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return (1.L / (1.L + (exp(-1.L * a)))); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpTanh<Rank, Dtype>::register_fcn() { // Check Tosa Level @@ -97,13 +109,16 @@ int OpTanh<Rank, Dtype>::register_fcn() switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(tanhf(a)); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return tanh(a); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported 
TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; @@ -115,11 +130,14 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP64); diff --git a/reference_model/src/ops/activation_funcs.h b/reference_model/src/ops/activation_funcs.h index 9a697cd..2372fcb 100644 --- a/reference_model/src/ops/activation_funcs.h +++ b/reference_model/src/ops/activation_funcs.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -24,7 +24,7 @@ using namespace tosa; namespace TosaReference { -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpClamp : public UnaryNode<Rank, Dtype> { public: @@ -45,7 +45,7 @@ protected: TosaClampAttribute* attribute; }; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpSigmoid : public UnaryNode<Rank, Dtype> { public: @@ -61,7 +61,7 @@ public: virtual int register_fcn(); }; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpTanh : public UnaryNode<Rank, Dtype> { public: diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc index a5711eb..8a084c7 100644 --- a/reference_model/src/ops/comparison.cc +++ b/reference_model/src/ops/comparison.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -22,7 +22,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpEqual<Rank, Dtype>::register_fcn() { // Check Tosa Level @@ -31,20 +31,21 @@ int OpEqual<Rank, Dtype>::register_fcn() switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: - case DType_INT32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_INT32: + case TOSA_REF_TYPE_FP64: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpGreater<Rank, Dtype>::register_fcn() { // Check Tosa Level @@ -53,20 +54,21 @@ int OpGreater<Rank, Dtype>::register_fcn() switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: - case DType_INT32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_INT32: + case TOSA_REF_TYPE_FP64: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpGreaterEqual<Rank, Dtype>::register_fcn() { // Check Tosa Level @@ -75,14 +77,15 @@ int OpGreaterEqual<Rank, Dtype>::register_fcn() switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: - case DType_INT32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_INT32: + case TOSA_REF_TYPE_FP64: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; }; 
break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; @@ -93,13 +96,16 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP64); diff --git a/reference_model/src/ops/comparison.h b/reference_model/src/ops/comparison.h index 29e6b5a..263df6a 100644 --- a/reference_model/src/ops/comparison.h +++ b/reference_model/src/ops/comparison.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -24,45 +24,45 @@ using namespace tosa; namespace TosaReference { -template <int Rank, DType Dtype> -class OpEqual : public BinaryNode<Rank, Dtype, DType_BOOL> +template <int Rank, TOSA_REF_TYPE Dtype> +class OpEqual : public BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL> { public: OpEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) - : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_EQUAL, id_) + : BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>(sgt_, Op_EQUAL, id_) { register_fcn(); } using InEigenType = typename GetEigenType<Dtype>::type; - using OutEigenType = typename GetEigenType<DType_BOOL>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type; virtual int register_fcn(); }; -template <int Rank, DType Dtype> -class OpGreater : public BinaryNode<Rank, Dtype, DType_BOOL> +template <int Rank, TOSA_REF_TYPE Dtype> +class OpGreater : public BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL> { public: OpGreater(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) - : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_GREATER, id_) + : BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>(sgt_, Op_GREATER, id_) { register_fcn(); } using InEigenType = typename GetEigenType<Dtype>::type; - using OutEigenType = typename GetEigenType<DType_BOOL>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type; virtual int register_fcn(); }; -template <int Rank, DType Dtype> -class OpGreaterEqual : public BinaryNode<Rank, Dtype, DType_BOOL> +template <int Rank, TOSA_REF_TYPE Dtype> +class OpGreaterEqual : public BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL> { public: OpGreaterEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) - : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_EQUAL, id_) + : BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>(sgt_, Op_EQUAL, id_) { register_fcn(); } using InEigenType = typename GetEigenType<Dtype>::type; - using OutEigenType = typename GetEigenType<DType_BOOL>::type; + using 
OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type; virtual int register_fcn(); }; diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc index f573d5b..03ad6c6 100644 --- a/reference_model/src/ops/control_flow.cc +++ b/reference_model/src/ops/control_flow.cc @@ -174,8 +174,8 @@ int OpCondIf::checkTensorAttributes() { ERROR_IF(getInputs().size() < 1, "OpCondIf: must have at least 1 operand"); - ERROR_IF(inputs[0]->getDtype() != DType_BOOL || inputs[0]->getRank() != 0, - "OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNamesDType()[inputs[0]->getDtype()], + ERROR_IF(inputs[0]->getDtype() != TOSA_REF_TYPE_BOOL || inputs[0]->getRank() != 0, + "OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNameTOSAREFTYPE(inputs[0]->getDtype()), inputs[0]->getRank()); cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]); @@ -223,9 +223,9 @@ int OpCondIf::checkTensorAttributes() std::string else_block_input_name = else_block->GetInputs()[i]; TosaSerializationTensor* then_block_input = then_block->GetTensorByName(then_block_input_name); TosaSerializationTensor* else_block_input = else_block->GetTensorByName(else_block_input_name); - ERROR_IF(operator_input->getDtype() != then_block_input->GetDtype(), + ERROR_IF(operator_input->getDtype() != ConvertDType(then_block_input->GetDtype()), "OpCondIf: input tensor type mismatch with then_block input type"); - ERROR_IF(operator_input->getDtype() != else_block_input->GetDtype(), + ERROR_IF(operator_input->getDtype() != ConvertDType(else_block_input->GetDtype()), "OpCondIf: input tensor type mismatch with else_block input type"); ERROR_IF(operator_input->getRank() != (int32_t)then_block_input->GetShape().size(), "OpCondIf: input tensor rank mismatch with then_block input rank"); @@ -247,9 +247,9 @@ int OpCondIf::checkTensorAttributes() std::string else_block_output_name = else_block->GetOutputs()[i]; TosaSerializationTensor* then_block_output = 
then_block->GetTensorByName(then_block_output_name); TosaSerializationTensor* else_block_output = else_block->GetTensorByName(else_block_output_name); - ERROR_IF(operator_output->getDtype() != then_block_output->GetDtype(), + ERROR_IF(operator_output->getDtype() != ConvertDType(then_block_output->GetDtype()), "OpCondIf: output tensor type mismatch with then_block output type"); - ERROR_IF(operator_output->getDtype() != else_block_output->GetDtype(), + ERROR_IF(operator_output->getDtype() != ConvertDType(else_block_output->GetDtype()), "OpCondIf: output tensor type mismatch with else_block output type"); ERROR_IF(operator_output->getRank() != (int32_t)then_block_output->GetShape().size(), "OpCondIf: output tensor rank mismatch with then_block output rank"); @@ -364,11 +364,11 @@ int OpWhileLoop::checkTensorAttributes() TosaSerializationTensor* body_block_input = body_block->GetTensorByName(body_block_input_name); TosaSerializationTensor* body_block_output = body_block->GetTensorByName(body_block_output_name); - ERROR_IF(operator_input->getDtype() != cond_block_input->GetDtype(), + ERROR_IF(operator_input->getDtype() != ConvertDType(cond_block_input->GetDtype()), "OpWhileLoop: input tensor type mismatch with cond_block input type"); - ERROR_IF(operator_input->getDtype() != body_block_input->GetDtype(), + ERROR_IF(operator_input->getDtype() != ConvertDType(body_block_input->GetDtype()), "OpWhileLoop: input tensor type mismatch with body_block input type"); - ERROR_IF(operator_input->getDtype() != body_block_output->GetDtype(), + ERROR_IF(operator_input->getDtype() != ConvertDType(body_block_output->GetDtype()), "OpWhileLoop: input tensor type mismatch with body_block output type"); ERROR_IF(operator_input->getRank() != (int32_t)cond_block_input->GetShape().size(), "OpWhileLoop: input tensor rank mismatch with cond_block input rank"); @@ -399,8 +399,7 @@ int OpWhileLoop::checkTensorAttributes() int OpWhileLoop::eval() { - - TosaReference::Tensor0<bool> 
cond_output_ctensor(std::string("cond_output"), DType_BOOL, std::vector<int32_t>({})); + TosaReference::Tensor0<bool> cond_output_ctensor("cond_output", DType_BOOL, std::vector<int32_t>({})); cond_output_ctensor.allocate(); std::vector<TosaReference::Tensor*> cond_block_outputs; diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index a189466..442cef8 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -20,7 +20,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -32,14 +32,14 @@ OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Axis); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpConcat<Rank, Dtype>::~OpConcat() { if (attribute) delete attribute; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpConcat<Rank, Dtype>::checkTensorAttributes() { // Check Tosa Level @@ -100,7 +100,7 @@ int OpConcat<Rank, Dtype>::checkTensorAttributes() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpConcat<Rank, Dtype>::eval() { @@ -124,7 +124,7 @@ int OpConcat<Rank, Dtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -136,12 +136,12 @@ OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Pad); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpPad<Rank, Dtype>::~OpPad() { } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpPad<Rank, Dtype>::checkTensorAttributes() { // Check Tosa Level @@ -185,22 +185,23 @@ int 
OpPad<Rank, Dtype>::checkTensorAttributes() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpPad<Rank, Dtype>::eval() { InEigenType pad_value = 0; switch (Dtype) { - case DType_BOOL: - case DType_INT8: - case DType_INT16: - case DType_INT32: + case TOSA_REF_TYPE_BOOL: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_INT32: pad_value = (InEigenType)attribute->pad_const_int(); break; - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_FP64: pad_value = (InEigenType)attribute->pad_const_fp(); break; default: @@ -213,7 +214,7 @@ int OpPad<Rank, Dtype>::eval() return GraphNode::eval(); } -template <int InRank, int OutRank, DType Dtype> +template <int InRank, int OutRank, TOSA_REF_TYPE Dtype> OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -225,14 +226,14 @@ OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Reshape); } -template <int InRank, int OutRank, DType Dtype> +template <int InRank, int OutRank, TOSA_REF_TYPE Dtype> OpReshape<InRank, OutRank, Dtype>::~OpReshape() { if (attribute) delete attribute; } -template <int InRank, int OutRank, DType Dtype> +template <int InRank, int OutRank, TOSA_REF_TYPE Dtype> int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes() { // Check Tosa Level @@ -270,7 +271,7 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes() return 0; } -template <int InRank, int OutRank, DType Dtype> +template <int InRank, int OutRank, TOSA_REF_TYPE Dtype> int OpReshape<InRank, OutRank, Dtype>::eval() { for (int32_t d = 0; d < OutRank; d++) @@ -313,7 +314,7 @@ int OpReshape<InRank, OutRank, Dtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpReverse<Rank, 
Dtype>::OpReverse(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -325,14 +326,14 @@ OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Axis); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpReverse<Rank, Dtype>::~OpReverse() { if (attribute) delete attribute; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpReverse<Rank, Dtype>::checkTensorAttributes() { // Check Tosa Level @@ -376,7 +377,7 @@ int OpReverse<Rank, Dtype>::checkTensorAttributes() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpReverse<Rank, Dtype>::eval() { out->getTensor() = in->getTensor().reverse(reverse_array); @@ -384,7 +385,7 @@ int OpReverse<Rank, Dtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -396,14 +397,14 @@ OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Slice); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpSlice<Rank, Dtype>::~OpSlice() { if (attribute) delete attribute; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpSlice<Rank, Dtype>::checkTensorAttributes() { // Check Tosa Level @@ -449,7 +450,7 @@ int OpSlice<Rank, Dtype>::checkTensorAttributes() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpSlice<Rank, Dtype>::eval() { out->getTensor() = in->getTensor().slice(begin_array, size_array); @@ -457,7 +458,7 @@ int OpSlice<Rank, Dtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -469,14 +470,14 @@ OpTileBase<Rank, 
Dtype>::OpTileBase(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Tile); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpTileBase<Rank, Dtype>::~OpTileBase() { if (attribute) delete attribute; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpTileBase<Rank, Dtype>::checkTensorAttributes() { // Check Tosa Level @@ -518,14 +519,14 @@ int OpTileBase<Rank, Dtype>::checkTensorAttributes() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpTile<Rank, Dtype>::eval() { // primary template shouldn't be called - FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]); + FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNameTOSAREFTYPE(Dtype)); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpTile<1, Dtype>::eval() { for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) @@ -537,7 +538,7 @@ int OpTile<1, Dtype>::eval() return GraphNode::eval(); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpTile<2, Dtype>::eval() { for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) @@ -553,7 +554,7 @@ int OpTile<2, Dtype>::eval() return GraphNode::eval(); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpTile<3, Dtype>::eval() { for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) @@ -573,7 +574,7 @@ int OpTile<3, Dtype>::eval() return GraphNode::eval(); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpTile<4, Dtype>::eval() { for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) @@ -597,7 +598,7 @@ int OpTile<4, Dtype>::eval() return GraphNode::eval(); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpTile<5, Dtype>::eval() { for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) @@ -626,7 +627,7 @@ int OpTile<5, Dtype>::eval() return GraphNode::eval(); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpTile<6, 
Dtype>::eval() { for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) @@ -659,7 +660,7 @@ int OpTile<6, Dtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -671,13 +672,13 @@ OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Transpose); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpTranspose<Rank, Dtype>::~OpTranspose() { if (attribute) delete attribute; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpTranspose<Rank, Dtype>::checkTensorAttributes() { // Check Tosa Level @@ -727,7 +728,7 @@ int OpTranspose<Rank, Dtype>::checkTensorAttributes() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpTranspose<Rank, Dtype>::eval() { out->getTensor() = in->getTensor().shuffle(perm_array); @@ -743,6 +744,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL) +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16); @@ -751,6 +753,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64); DEF_INSTANTIATE_RESHAPE(OpReshape, FP16); DEF_INSTANTIATE_RESHAPE(OpReshape, BF16); @@ -759,6 +762,7 @@ DEF_INSTANTIATE_RESHAPE(OpReshape, INT8); DEF_INSTANTIATE_RESHAPE(OpReshape, INT16); DEF_INSTANTIATE_RESHAPE(OpReshape, INT32); 
DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL); +DEF_INSTANTIATE_RESHAPE(OpReshape, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16); @@ -767,6 +771,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16); @@ -775,6 +780,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16); @@ -783,6 +789,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16); @@ -791,6 +798,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16); @@ -799,3 +807,4 @@ 
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64); diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h index 3a6cb0d..94ce248 100644 --- a/reference_model/src/ops/data_layout.h +++ b/reference_model/src/ops/data_layout.h @@ -23,7 +23,7 @@ using namespace tosa; namespace TosaReference { -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpConcat : public GraphNode { public: @@ -45,7 +45,7 @@ protected: TosaReference::TensorTemplate<TOut>* out; }; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpPad : public GraphNode { public: @@ -66,7 +66,7 @@ protected: TosaPadAttribute* attribute; }; -template <int InRank, int OutRank, DType Dtype> +template <int InRank, int OutRank, TOSA_REF_TYPE Dtype> class OpReshape : public GraphNode { public: @@ -90,7 +90,7 @@ protected: TosaReference::TensorTemplate<TOut>* out; }; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpReverse : public GraphNode { public: @@ -112,7 +112,7 @@ protected: Eigen::array<bool, Rank> reverse_array; }; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpSlice : public GraphNode { public: @@ -135,7 +135,7 @@ protected: TosaReference::TensorTemplate<TOut>* out; }; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpTileBase : public GraphNode { public: @@ -156,7 +156,7 @@ protected: }; // primary template for op tile -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpTile : public OpTileBase<Rank, Dtype> { public: @@ -170,12 +170,12 @@ protected: // partial specialization for specific rank #define DEF_OP_TILE_RANK(N) 
\ - template <DType Dtype> \ + template <TOSA_REF_TYPE Dtype> \ class OpTile<N, Dtype> : public OpTileBase<N, Dtype> \ { \ public: \ - OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \ - : OpTileBase<N, Dtype>(sgt_, attribute_, id_) \ + OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \ + : OpTileBase<N, Dtype>(sgt_, attribute_, id_) \ {} \ \ protected: \ @@ -191,7 +191,7 @@ DEF_OP_TILE_RANK(6) #undef DEF_OP_TILE_RANK -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpTranspose : public GraphNode { public: diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc index f5304a5..b7f987a 100644 --- a/reference_model/src/ops/data_nodes.cc +++ b/reference_model/src/ops/data_nodes.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -42,7 +42,7 @@ int OpConst::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpIdentity<Rank, Dtype>::OpIdentity(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -52,11 +52,11 @@ OpIdentity<Rank, Dtype>::OpIdentity(SubgraphTraverser* sgt_, setRequiredRank(0, 6); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpIdentity<Rank, Dtype>::~OpIdentity() {} -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpIdentity<Rank, Dtype>::checkTensorAttributes() { @@ -78,7 +78,7 @@ int OpIdentity<Rank, Dtype>::checkTensorAttributes() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpIdentity<Rank, Dtype>::eval() { out->getTensor() = in->getTensor(); @@ -96,3 +96,4 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64); diff --git a/reference_model/src/ops/data_nodes.h b/reference_model/src/ops/data_nodes.h index 8761a08..395c667 100644 --- a/reference_model/src/ops/data_nodes.h +++ b/reference_model/src/ops/data_nodes.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -31,7 +31,7 @@ public: virtual int eval(); }; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpIdentity : public GraphNode { public: diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc index 6aa0c0f..c5801e7 100644 --- a/reference_model/src/ops/ewise_binary.cc +++ b/reference_model/src/ops/ewise_binary.cc @@ -22,7 +22,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_) @@ -37,11 +37,11 @@ BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_, fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); }; } -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase() {} -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes() { // Check Tosa Level @@ -90,7 +90,7 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes() return 0; } -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast() { const std::vector<int>& a_shape = a->getShape(); @@ -106,7 +106,7 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast() return 0; } -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> int BinaryNode<Rank, InDtype, OutDtype>::eval() { this->broadcast(); @@ -124,7 +124,7 @@ int BinaryNode<Rank, InDtype, OutDtype>::eval() } // still need to partial 
specialize this, or Eigen will throw static assertion -template <DType InDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> int BinaryNode<0, InDtype, OutDtype>::eval() { this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn); @@ -132,12 +132,12 @@ int BinaryNode<0, InDtype, OutDtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpAdd<Rank, Dtype>::register_fcn() { switch (InDtype) { - case DType_INT32: + case TOSA_REF_TYPE_INT32: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { int64_t res_in_64 = static_cast<int64_t>(a) + b; int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max()); @@ -146,36 +146,39 @@ int OpAdd<Rank, Dtype>::register_fcn() return static_cast<InEigenType>(res_in_64); }; break; - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a + b); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpArithmeticRightShift<Rank, Dtype>::register_fcn() { bool round = attribute->round(); int32_t num_bits = 0; switch (Dtype) { - case DType_INT8: + case TOSA_REF_TYPE_INT8: num_bits = 8; break; - case DType_INT16: + case TOSA_REF_TYPE_INT16: num_bits = 16; break; - case DType_INT32: + case TOSA_REF_TYPE_INT32: num_bits = 32; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } 
this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType { @@ -195,69 +198,69 @@ int OpArithmeticRightShift<Rank, Dtype>::register_fcn() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpArithmeticRightShift<Rank, Dtype>::~OpArithmeticRightShift() { if (attribute) delete attribute; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpBitwiseAnd<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_INT8: - case DType_INT16: - case DType_INT32: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpBitwiseOr<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_INT8: - case DType_INT16: - case DType_INT32: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpBitwiseXor<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_INT8: - case DType_INT16: - case DType_INT32: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> 
+template <int Rank, TOSA_REF_TYPE Dtype> int OpIntdiv<Rank, Dtype>::register_fcn() { switch (InDtype) { - case DType_INT32: + case TOSA_REF_TYPE_INT32: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b != 0, "OpIntDiv: divisor must be non-zero value"); int64_t res_in_64 = static_cast<int64_t>(a) / b; @@ -268,47 +271,47 @@ int OpIntdiv<Rank, Dtype>::register_fcn() }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpLogicalAnd<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_BOOL: + case TOSA_REF_TYPE_BOOL: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpLogicalLeftShift<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_INT8: + case TOSA_REF_TYPE_INT8: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", (int32_t)b); return static_cast<OutEigenType>(static_cast<int8_t>(a << b)); }; break; - case DType_INT16: + case TOSA_REF_TYPE_INT16: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", (int32_t)b); return static_cast<OutEigenType>(static_cast<int16_t>(a << b)); }; break; - case DType_INT32: + case TOSA_REF_TYPE_INT32: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]", (int32_t)b); @@ -316,32 +319,32 
@@ int OpLogicalLeftShift<Rank, Dtype>::register_fcn() }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpLogicalRightShift<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_INT8: + case TOSA_REF_TYPE_INT8: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", (int32_t)b); return static_cast<OutEigenType>(static_cast<int8_t>(a) >> b); }; break; - case DType_INT16: + case TOSA_REF_TYPE_INT16: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", (int32_t)b); return static_cast<OutEigenType>(static_cast<int16_t>(a) >> b); }; break; - case DType_INT32: + case TOSA_REF_TYPE_INT32: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]", (int32_t)b); @@ -349,91 +352,96 @@ int OpLogicalRightShift<Rank, Dtype>::register_fcn() }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpLogicalOr<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_BOOL: + case TOSA_REF_TYPE_BOOL: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpLogicalXor<Rank, 
Dtype>::register_fcn() { switch (Dtype) { - case DType_BOOL: + case TOSA_REF_TYPE_BOOL: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpMaximum<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: - case DType_INT32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_FP64: + case TOSA_REF_TYPE_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpMinimum<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: - case DType_INT32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_FP64: + case TOSA_REF_TYPE_INT32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? 
a : b; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> int OpMul<Rank, InDtype, OutDtype>::register_fcn() { int32_t shift = attribute->shift(); switch (InDtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a * b); }; break; - case DType_INT32: + case TOSA_REF_TYPE_FP64: + this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; }; + break; + case TOSA_REF_TYPE_INT32: this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType { int64_t result; if (shift > 0) @@ -457,8 +465,8 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn() return static_cast<OutEigenType>(result); }; break; - case DType_INT8: - case DType_INT16: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType { OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs; @@ -468,41 +476,44 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn() }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype)); } return 0; } -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> OpMul<Rank, InDtype, OutDtype>::~OpMul() { if (attribute) delete attribute; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpPow<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case 
TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(powf(a, b)); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return pow(a, b); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpSub<Rank, Dtype>::register_fcn() { switch (InDtype) { - case DType_INT32: + case TOSA_REF_TYPE_INT32: this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType { int64_t res_in_64 = static_cast<int64_t>(a) - b; int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max()); @@ -511,19 +522,22 @@ int OpSub<Rank, Dtype>::register_fcn() return static_cast<InEigenType>(res_in_64); }; break; - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a - b); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype)); } return 0; } -template <int Rank, DType InDtype> +template <int Rank, TOSA_REF_TYPE InDtype> OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -535,13 +549,13 @@ OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Table); } -template <int Rank, DType InDtype> +template <int Rank, TOSA_REF_TYPE InDtype> OpTable<Rank, InDtype>::~OpTable() { if (attribute) delete attribute; } -template <int Rank, DType InDtype> +template <int Rank, TOSA_REF_TYPE InDtype> int OpTable<Rank, 
InDtype>::checkTensorAttributes() { // Check Tosa Level @@ -573,12 +587,12 @@ int OpTable<Rank, InDtype>::checkTensorAttributes() return 0; } -template <int Rank, DType InDtype> +template <int Rank, TOSA_REF_TYPE InDtype> int OpTable<Rank, InDtype>::eval() { switch (InDtype) { - case DType_INT8: + case TOSA_REF_TYPE_INT8: this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax); int32_t index = input_truncated - QInMin; @@ -587,7 +601,7 @@ int OpTable<Rank, InDtype>::eval() return value; }); break; - case DType_INT16: + case TOSA_REF_TYPE_INT16: this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { // 1. make sure input is int16 range int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax); @@ -610,7 +624,7 @@ int OpTable<Rank, InDtype>::eval() }); break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype)); } return GraphNode::eval(); @@ -630,11 +644,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP16, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, BF16, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP32, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT32, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP64, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16); @@ -672,11 +688,13 @@ 
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16); @@ -684,15 +702,18 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP64, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16); @@ -703,3 +724,4 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, BF16, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP32, BOOL); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL); 
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP64, BOOL); diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h index 020ddb5..5f6e531 100644 --- a/reference_model/src/ops/ewise_binary.h +++ b/reference_model/src/ops/ewise_binary.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ namespace TosaReference // the way of registering lambda + .binaryExpr() might sacrifice performance here // but it can avoid partially specialization for combination of {rankN, rank0} x {FP32/INT32, QU8, ...} // needs to revisit if performance becomes a bottleneck here -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> class BinaryNodeBase : public GraphNode { public: @@ -67,7 +67,7 @@ protected: }; // primary class -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> class BinaryNode : public BinaryNodeBase<Rank, InDtype, OutDtype> { public: @@ -86,7 +86,7 @@ public: }; // partial specialization for rank 0 -template <DType InDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype> { public: @@ -100,19 +100,19 @@ public: }; #define DEF_TEMPLATE_BINARY_OP_DEFAULT(Opname, OPNAME) \ - template <int Rank, DType Dtype> \ + template <int Rank, TOSA_REF_TYPE Dtype> \ class Op##Opname : public BinaryNode<Rank, Dtype, Dtype> \ { \ public: \ - Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \ - : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_##OPNAME, id_) \ + Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \ + : BinaryNode<Rank, Dtype, 
Dtype>(sgt_, Op_##OPNAME, id_) \ { \ register_fcn(); \ } \ - static constexpr DType InDtype = Dtype; \ - static constexpr DType OutDtype = Dtype; \ - using InEigenType = typename GetEigenType<InDtype>::type; \ - using OutEigenType = typename GetEigenType<OutDtype>::type; \ + static constexpr TOSA_REF_TYPE InDtype = Dtype; \ + static constexpr TOSA_REF_TYPE OutDtype = Dtype; \ + using InEigenType = typename GetEigenType<InDtype>::type; \ + using OutEigenType = typename GetEigenType<OutDtype>::type; \ virtual int register_fcn(); \ }; @@ -133,7 +133,7 @@ DEF_TEMPLATE_BINARY_OP_DEFAULT(Sub, SUB) #undef DEF_TEMPLATE_BINARY_OP_DEFAULT -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpArithmeticRightShift : public BinaryNode<Rank, Dtype, Dtype> { public: @@ -154,7 +154,7 @@ protected: TosaArithmeticRightShiftAttribute* attribute; }; -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> class OpMul : public BinaryNode<Rank, InDtype, OutDtype> { public: @@ -175,7 +175,7 @@ protected: TosaMulAttribute* attribute; }; -template <int Rank, DType InDtype> +template <int Rank, TOSA_REF_TYPE InDtype> class OpTable : public GraphNode { public: @@ -185,9 +185,11 @@ public: virtual int checkTensorAttributes(); virtual int eval(); - static constexpr DType TableDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT16; - static constexpr DType OutDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT32; - static constexpr uint32_t TableNumEntries = (InDtype == DType_INT8) ? 256 : 513; + static constexpr TOSA_REF_TYPE TableDtype = + (InDtype == TOSA_REF_TYPE_INT8) ? TOSA_REF_TYPE_INT8 : TOSA_REF_TYPE_INT16; + static constexpr TOSA_REF_TYPE OutDtype = + (InDtype == TOSA_REF_TYPE_INT8) ? TOSA_REF_TYPE_INT8 : TOSA_REF_TYPE_INT32; + static constexpr uint32_t TableNumEntries = (InDtype == TOSA_REF_TYPE_INT8) ? 
256 : 513; using InEigenType = typename GetEigenType<InDtype>::type; using TableEigenType = typename GetEigenType<TableDtype>::type; using OutEigenType = typename GetEigenType<OutDtype>::type; diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc index 4d53ae4..090ce29 100644 --- a/reference_model/src/ops/ewise_ternary.cc +++ b/reference_model/src/ops/ewise_ternary.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpSelectBase<Rank, Dtype>::OpSelectBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -29,11 +29,11 @@ OpSelectBase<Rank, Dtype>::OpSelectBase(SubgraphTraverser* sgt_, setRequiredRank(0, 6); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpSelectBase<Rank, Dtype>::~OpSelectBase() {} -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpSelectBase<Rank, Dtype>::checkTensorAttributes() { // Check Tosa Level @@ -66,13 +66,13 @@ int OpSelectBase<Rank, Dtype>::checkTensorAttributes() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpSelectBase<Rank, Dtype>::eval() { FATAL_ERROR("shouldn't be called"); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpSelect<Rank, Dtype>::broadcast() { const std::vector<int>& cond_shape = this->cond->getShape(); @@ -90,7 +90,7 @@ int OpSelect<Rank, Dtype>::broadcast() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpSelect<Rank, Dtype>::eval() { this->broadcast(); @@ -102,7 +102,7 @@ int OpSelect<Rank, Dtype>::eval() return 
GraphNode::eval(); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpSelect<0, Dtype>::eval() { this->out->getTensor() = this->cond->getTensor().select(this->then_val->getTensor(), this->else_val->getTensor()); @@ -118,6 +118,7 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BF16); @@ -126,3 +127,4 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP64); diff --git a/reference_model/src/ops/ewise_ternary.h b/reference_model/src/ops/ewise_ternary.h index 75a2194..c6970cb 100644 --- a/reference_model/src/ops/ewise_ternary.h +++ b/reference_model/src/ops/ewise_ternary.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ namespace TosaReference // 3. Else_val: Rank N, type=<V> // 4. 
Result: Rank N, type=<V> // Cond, Then_val, Else_val need to be mutually-broadcastable -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpSelectBase : public GraphNode { public: @@ -39,7 +39,7 @@ public: virtual int checkTensorAttributes(); virtual int eval(); - using CondEigenType = typename GetEigenType<DType_BOOL>::type; + using CondEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type; using InEigenType = typename GetEigenType<Dtype>::type; using TCond = Eigen::Tensor<CondEigenType, Rank>; using TIn = Eigen::Tensor<InEigenType, Rank>; @@ -55,7 +55,7 @@ protected: }; // primary class -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpSelect : public OpSelectBase<Rank, Dtype> { public: @@ -69,7 +69,7 @@ public: }; // partial specialization for rank 0 -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> class OpSelect<0, Dtype> : public OpSelectBase<0, Dtype> { public: diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc index 8dc37e2..514cb84 100644 --- a/reference_model/src/ops/ewise_unary.cc +++ b/reference_model/src/ops/ewise_unary.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -22,7 +22,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> UnaryNode<Rank, Dtype>::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_) : GraphNode(sgt_, op_, id_) { @@ -35,11 +35,11 @@ UnaryNode<Rank, Dtype>::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64 }; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> UnaryNode<Rank, Dtype>::~UnaryNode() {} -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int UnaryNode<Rank, Dtype>::checkTensorAttributes() { // Check Tosa Level @@ -69,7 +69,7 @@ int UnaryNode<Rank, Dtype>::checkTensorAttributes() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int UnaryNode<Rank, Dtype>::eval() { this->result->getTensor() = this->a->getTensor().unaryExpr(this->fcn); @@ -77,71 +77,75 @@ int UnaryNode<Rank, Dtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpAbs<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_FP32: // No fpTrunc for FP32 as it is a no-op - case DType_INT32: + case TOSA_REF_TYPE_FP32: // No fpTrunc for FP32 as it is a no-op + case TOSA_REF_TYPE_FP64: + case TOSA_REF_TYPE_INT32: this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); }; break; - case DType_FP16: - case DType_BF16: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(a > (InEigenType)0 ? 
a : (-a)); }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpBitwiseNot<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_INT8: - case DType_INT16: - case DType_INT32: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_INT32: this->fcn = [](InEigenType a) -> OutEigenType { return ~a; }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpCeil<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(ceilf(a)); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return ceil(a); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpClz<Rank, Dtype>::register_fcn() { int32_t num_bits; switch (Dtype) { - case DType_INT32: + case TOSA_REF_TYPE_INT32: num_bits = 32; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } this->fcn = [num_bits](int32_t a) -> int32_t { @@ -163,73 +167,82 @@ int OpClz<Rank, Dtype>::register_fcn() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpExp<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_FP16: - 
case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(expf(a)); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return exp(a); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpFloor<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(floorf(a)); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return floor(a); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpLog<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(logf(a)); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return log(a); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpLogicalNot<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_BOOL: + case TOSA_REF_TYPE_BOOL: this->fcn = [](InEigenType a) -> OutEigenType { return !a; 
}; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpNegate<Rank, Dtype>::OpNegate(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -240,31 +253,37 @@ OpNegate<Rank, Dtype>::OpNegate(SubgraphTraverser* sgt_, register_fcn(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpNegate<Rank, Dtype>::~OpNegate() { if (attribute) delete attribute; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpNegate<Rank, Dtype>::register_fcn() { - ERROR_IF(Dtype != DType_INT8 && attribute->input1_zp() != 0, "OpNegate: zeropoint only for int8_t"); - ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpNegate: zeropoint only for int8_t"); + ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->input1_zp() != 0, "OpNegate: zeropoint only for int8_t"); + ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->output_zp() != 0, "OpNegate: zeropoint only for int8_t"); switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { InEigenType result = -(a); return fpTrunc<Dtype>(result); }; break; - case DType_INT16: - case DType_INT32: + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { + OutEigenType result = -(a); + return result; + }; + break; + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_INT32: this->fcn = [this](InEigenType a) -> OutEigenType { int64_t res_in_64 = 0L - a; int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max()); @@ -272,7 +291,7 @@ int OpNegate<Rank, Dtype>::register_fcn() REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpNegate: result not in acc type range (int32)"); 
int64_t max_clip_in_64, min_clip_in_64; - if (Dtype == DType_INT16) + if (Dtype == TOSA_REF_TYPE_INT16) { max_clip_in_64 = static_cast<int64_t>(std::numeric_limits<int16_t>::max()); min_clip_in_64 = static_cast<int64_t>(std::numeric_limits<int16_t>::min()); @@ -285,7 +304,7 @@ int OpNegate<Rank, Dtype>::register_fcn() return static_cast<InEigenType>(std::min<int64_t>(max_clip_in_64, std::max<int64_t>(min_clip_in_64, res_in_64))); }; break; - case DType_INT8: + case TOSA_REF_TYPE_INT8: this->fcn = [this](InEigenType a) -> OutEigenType { int64_t res_in_64 = 0 - (a - attribute->input1_zp()); int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max()); @@ -297,41 +316,47 @@ int OpNegate<Rank, Dtype>::register_fcn() }; break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpReciprocal<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(1.0 / a); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> OutEigenType { return (1.0L / a); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpRsqrt<Rank, Dtype>::register_fcn() { switch (Dtype) { - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(1.0 / sqrtf(a)); }; break; + case TOSA_REF_TYPE_FP64: + this->fcn = [](InEigenType a) -> 
OutEigenType { return (1.0L / sqrt(a)); }; + break; default: - ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]); + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return 0; @@ -345,11 +370,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16); @@ -358,20 +385,24 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, 
FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL); @@ -381,11 +412,14 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP64); diff --git a/reference_model/src/ops/ewise_unary.h b/reference_model/src/ops/ewise_unary.h index 16a4c88..21ee276 100644 --- a/reference_model/src/ops/ewise_unary.h +++ b/reference_model/src/ops/ewise_unary.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -22,7 +22,7 @@ using namespace tosa; namespace TosaReference { -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class UnaryNode : public GraphNode { public: @@ -45,11 +45,11 @@ protected: }; #define DEF_TEMPLATE_UNARY_OP(Opname, OPNAME) \ - template <int Rank, DType Dtype> \ + template <int Rank, TOSA_REF_TYPE Dtype> \ class Op##Opname : public UnaryNode<Rank, Dtype> \ { \ public: \ - Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \ + Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \ : UnaryNode<Rank, Dtype>(sgt_, Op_##OPNAME, id_) \ { \ register_fcn(); \ @@ -75,7 +75,7 @@ DEF_TEMPLATE_UNARY_OP(Rsqrt, RSQRT) #undef DEF_TEMPLATE_UNARY_OP // Negate is the only unary op with attributes -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpNegate : public UnaryNode<Rank, Dtype> { public: diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index 190b354..ca12cfe 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -23,7 +23,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template <DType InDtype, DType OutDtype, typename resize_t> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t> OpResize<InDtype, OutDtype, resize_t>::OpResize(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -35,14 +35,14 @@ OpResize<InDtype, OutDtype, resize_t>::OpResize(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Resize); } -template <DType InDtype, DType OutDtype, typename resize_t> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t> OpResize<InDtype, OutDtype, resize_t>::~OpResize() { if (attribute) delete attribute; } -template <DType InDtype, DType OutDtype, typename resize_t> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t> int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes() { if (validateRequiredOperands()) @@ -64,7 +64,8 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes() if (this->mode == ResizeMode_BILINEAR) { - if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16) + if (OutDtype != TOSA_REF_TYPE_INT32 && OutDtype != TOSA_REF_TYPE_INT48 && OutDtype != TOSA_REF_TYPE_FP32 && + OutDtype != TOSA_REF_TYPE_FP16 && OutDtype != TOSA_REF_TYPE_BF16 && OutDtype != TOSA_REF_TYPE_FP64) { printNodeValidationError("OpResize: invalid data type for BILINEAR"); return 1; @@ -72,7 +73,8 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes() } else { - if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16) + if (OutDtype != TOSA_REF_TYPE_INT8 && OutDtype != TOSA_REF_TYPE_INT16 && OutDtype != TOSA_REF_TYPE_FP32 && + OutDtype != TOSA_REF_TYPE_FP16 && OutDtype != TOSA_REF_TYPE_BF16 && OutDtype != TOSA_REF_TYPE_FP64) { printNodeValidationError("OpResize: invalid data type for 
NEAREST"); return 1; @@ -87,7 +89,7 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes() return 0; } -template <DType InDtype, DType OutDtype, typename resize_t> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t> int OpResize<InDtype, OutDtype, resize_t>::eval() { int in_batch = in->getShape()[0]; @@ -157,24 +159,38 @@ int OpResize<InDtype, OutDtype, resize_t>::eval() int32_t y = oy * scale_y_d + offset_y; int32_t x = ox * scale_x_d + offset_x; - float fy = static_cast<float>(y) / static_cast<float>(scale_y_n); - float fx = static_cast<float>(x) / static_cast<float>(scale_x_n); - - int32_t iy = floor(fy); - int32_t ix = floor(fx); - + int32_t iy; + int32_t ix; resize_t dy; resize_t dx; - if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) || - (typeid(resize_t) == typeid(half_float::half))) + if (std::is_same<resize_t, double>::value) { - dy = (resize_t)(fy - iy); - dx = (resize_t)(fx - ix); + const double fy_double = static_cast<double>(y) / static_cast<double>(scale_y_n); + const double fx_double = static_cast<double>(x) / static_cast<double>(scale_x_n); + iy = floor(fy_double); + ix = floor(fx_double); + + dy = (resize_t)(fy_double - iy); + dx = (resize_t)(fx_double - ix); } else { - dy = (resize_t)(y - (iy * scale_y_n)); - dx = (resize_t)(x - (ix * scale_x_n)); + const float fy = static_cast<float>(y) / static_cast<float>(scale_y_n); + const float fx = static_cast<float>(x) / static_cast<float>(scale_x_n); + iy = floor(fy); + ix = floor(fx); + + if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) || + (typeid(resize_t) == typeid(half_float::half))) + { + dy = (resize_t)(fy - iy); + dx = (resize_t)(fx - ix); + } + else + { + dy = (resize_t)(y - (iy * scale_y_n)); + dx = (resize_t)(x - (ix * scale_x_n)); + } } int32_t iy0 = MAX(iy, 0); @@ -248,3 +264,4 @@ DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, INT16, INT16, int16_t); 
DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP16, FP16, half_float::half); DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, BF16, BF16, Eigen::bfloat16); DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP32, FP32, float); +DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP64, FP64, double); diff --git a/reference_model/src/ops/image.h b/reference_model/src/ops/image.h index 508d2c8..6d5a418 100644 --- a/reference_model/src/ops/image.h +++ b/reference_model/src/ops/image.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -23,7 +23,7 @@ using namespace tosa; namespace TosaReference { -template <DType InDtype, DType OutDtype, typename resize_t> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t> class OpResize : public GraphNode { public: diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 1db3974..0a78884 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -37,11 +37,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, Op opType, TosaAttributeBase* attribute, uint64_t id, - DType inputDType, + TOSA_REF_TYPE inputDTYPE, int inputRank, - DType outputDType, + TOSA_REF_TYPE outputDTYPE, int outputRank, - DType weightDType, + TOSA_REF_TYPE weightDTYPE, int weightRank) { switch (opType) @@ -53,6 +53,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64); break; case Op_AVG_POOL2D: DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16); @@ -61,6 +62,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, 
FP32, FP32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32); DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32); + DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP64, FP64); break; case Op_CONV2D: DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16); @@ -70,6 +72,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT4, INT32); DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64); break; case Op_CONV3D: DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16); @@ -79,6 +82,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT4, INT32); DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpConv3d, FP64, FP64, FP64); break; case Op_DEPTHWISE_CONV2D: DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16); @@ -88,9 +92,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32); DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64); break; case Op_FFT2D: DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32); + DEF_FACTORY_ONE_TYPE(OpFFT2d, FP64); break; case Op_FULLY_CONNECTED: DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16); @@ -100,6 +106,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32); DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP64, FP64, FP64); break; case Op_MATMUL: DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP16, FP16); @@ -108,6 +115,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, 
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP32, FP32); DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT8, INT32); DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT16, INT48); + DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP64, FP64); break; case Op_MAX_POOL2D: DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16); @@ -115,9 +123,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP32); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8); DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16); + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP64); break; case Op_RFFT2D: DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP32); + DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP64); break; case Op_TRANSPOSE_CONV2D: DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16); @@ -127,6 +137,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32); DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32); DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64); break; // activation_funcs @@ -136,16 +147,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP64); break; case Op_SIGMOID: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP64); break; case Op_TANH: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP64); break; // ewise_binary @@ -154,6 +168,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16); 
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP64); break; case Op_ARITHMETIC_RIGHT_SHIFT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8); @@ -202,12 +217,14 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP64); break; case Op_MINIMUM: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP64); break; case Op_MUL: DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16); @@ -216,17 +233,20 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP64, FP64); break; case Op_POW: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP64); break; case Op_SUB: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP64); break; case Op_TABLE: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8); @@ -239,6 +259,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16); 
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP64); break; case Op_BITWISE_NOT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT8); @@ -249,6 +270,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP64); break; case Op_CLZ: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); @@ -257,16 +279,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP64); break; case Op_FLOOR: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP64); break; case Op_LOG: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP64); break; case Op_LOGICAL_NOT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL); @@ -278,16 +303,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP64); break; case Op_RECIPROCAL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32); + 
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP64); break; case Op_RSQRT: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64); break; // ewise_ternary @@ -299,6 +327,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP64); break; // comparison @@ -307,18 +336,21 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP64); break; case Op_GREATER: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP64); break; case Op_GREATER_EQUAL: DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP64); break; // reduction @@ -335,6 +367,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP64); break; case Op_REDUCE_MIN: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16); @@ 
-343,16 +376,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP64); break; case Op_REDUCE_PRODUCT: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProductDouble, FP64); break; case Op_REDUCE_SUM: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumDouble, FP64); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32); break; @@ -365,6 +401,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64); break; case Op_PAD: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); @@ -374,6 +411,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64); break; case Op_RESHAPE: DEF_FACTORY_RESHAPE(OpReshape, FP16); @@ -383,6 +421,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RESHAPE(OpReshape, INT16); DEF_FACTORY_RESHAPE(OpReshape, INT32); DEF_FACTORY_RESHAPE(OpReshape, BOOL); + DEF_FACTORY_RESHAPE(OpReshape, FP64); break; case Op_REVERSE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); @@ -392,6 +431,7 @@ GraphNode* 
OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64); break; case Op_SLICE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); @@ -401,6 +441,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64); break; case Op_TILE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16); @@ -410,6 +451,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64); break; case Op_TRANSPOSE: DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); @@ -419,6 +461,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64); break; // scatter_gather @@ -429,6 +472,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpGather, FP16); DEF_FACTORY_ONE_TYPE(OpGather, BF16); DEF_FACTORY_ONE_TYPE(OpGather, FP32); + DEF_FACTORY_ONE_TYPE(OpGather, FP64); break; case Op_SCATTER: DEF_FACTORY_ONE_TYPE(OpScatter, INT8); @@ -437,6 +481,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpScatter, FP16); DEF_FACTORY_ONE_TYPE(OpScatter, BF16); DEF_FACTORY_ONE_TYPE(OpScatter, FP32); + DEF_FACTORY_ONE_TYPE(OpScatter, FP64); break; // image @@ -448,6 +493,7 @@ GraphNode* 
OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OpResize, FP16, FP16); DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OpResize, BF16, BF16); DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OpResize, FP32, FP32); + DEF_FACTORY_TWO_TYPE_RESIZE_FP64(OpResize, FP64, FP64); break; // data_nodes @@ -461,6 +507,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64); break; // type_conversion @@ -499,6 +546,13 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16); DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64); break; case Op_RESCALE: DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8); diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h index 9117df4..f276e03 100644 --- a/reference_model/src/ops/op_factory.h +++ b/reference_model/src/ops/op_factory.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -23,19 +23,19 @@ #define DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) \ case RANK: \ - return new OP<RANK, DType_##DTYPE>(sgt, attribute, id); + return new OP<RANK, TOSA_REF_TYPE_##DTYPE>(sgt, attribute, id); #define DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \ case RANK: \ - return new OP<RANK, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); + return new OP<RANK, TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id); #define DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \ case RANK2: \ - return new OP<RANK1, RANK2, DType_##DTYPE>(sgt, attribute, id); + return new OP<RANK1, RANK2, TOSA_REF_TYPE_##DTYPE>(sgt, attribute, id); #define DEF_FACTORY_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \ case RANK2: \ - return new OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); + return new OP<RANK1, RANK2, TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id); #define DEF_FACTORY_ONE_RANK_0_6(OP) \ switch (inputRank) \ @@ -57,40 +57,42 @@ } #define DEF_FACTORY_ONE_TYPE(OP, DTYPE) \ - if (inputDType == DType_##DTYPE) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE) \ { \ - return new OP<DType_##DTYPE>(sgt, attribute, id); \ + return new OP<TOSA_REF_TYPE_##DTYPE>(sgt, attribute, id); \ } #define DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE, ACCUM_DTYPE) \ - if (inputDType == DType_##DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == TOSA_REF_TYPE_##ACCUM_DTYPE) \ { \ - return new OP<DType_##DTYPE, DType_##ACCUM_DTYPE>(sgt, attribute, id); \ + return new OP<TOSA_REF_TYPE_##DTYPE, TOSA_REF_TYPE_##ACCUM_DTYPE>(sgt, attribute, id); \ } #define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \ - if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && weightDTYPE == TOSA_REF_TYPE_##DTYPE2) \ { \ - return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, 
attribute, id); \ + return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id); \ } #define DEF_FACTORY_TWO_TYPE_IN_OUT(OP, DTYPE1, DTYPE2) \ - if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \ { \ - return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); \ + return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id); \ } #define DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE1, DTYPE2, ACCUM_DTYPE) \ - if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 \ - && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && weightDTYPE == TOSA_REF_TYPE_##DTYPE2 && \ + ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == TOSA_REF_TYPE_##ACCUM_DTYPE) \ { \ - return new OP<DType_##DTYPE1, DType_##DTYPE2, DType_##ACCUM_DTYPE>(sgt, attribute, id); \ - } \ + return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##ACCUM_DTYPE>(sgt, attribute, \ + id); \ + } #define DEF_FACTORY_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3) \ - if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 && outputDType == DType_##DTYPE3) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && weightDTYPE == TOSA_REF_TYPE_##DTYPE2 && \ + outputDTYPE == TOSA_REF_TYPE_##DTYPE3) \ { \ - return new OP<DType_##DTYPE1, DType_##DTYPE2, DType_##DTYPE3>(sgt, attribute, id); \ + return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##DTYPE3>(sgt, attribute, id); \ } // Statement-expression to evaluate accumulate attribute in-place @@ -108,35 +110,41 @@ FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute.\nPre-initialization " \ "of this attribute is required in order to determine the accumulate type."); \ } \ - accumDType; \ - }) \ + ConvertDType(accumDType); \ + }) #define DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OP, DTYPE1, DTYPE2) \ - if (inputDType == 
DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \ { \ - return new OP<DType_##DTYPE1, DType_##DTYPE2, int16_t>(sgt, attribute, id); \ + return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, int16_t>(sgt, attribute, id); \ } #define DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OP, DTYPE1, DTYPE2) \ - if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \ { \ - return new OP<DType_##DTYPE1, DType_##DTYPE2, half_float::half>(sgt, attribute, id); \ + return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, half_float::half>(sgt, attribute, id); \ } #define DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OP, DTYPE1, DTYPE2) \ - if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \ + { \ + return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, Eigen::bfloat16>(sgt, attribute, id); \ + } + +#define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP, DTYPE1, DTYPE2) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \ { \ - return new OP<DType_##DTYPE1, DType_##DTYPE2, Eigen::bfloat16>(sgt, attribute, id); \ + return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, float>(sgt, attribute, id); \ } -#define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP, DTYPE1, DTYPE2) \ - if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ +#define DEF_FACTORY_TWO_TYPE_RESIZE_FP64(OP, DTYPE1, DTYPE2) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \ { \ - return new OP<DType_##DTYPE1, DType_##DTYPE2, float>(sgt, attribute, id); \ + return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, double>(sgt, attribute, id); \ } #define DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \ - if (inputDType == DType_##DTYPE) \ + if (inputDTYPE == 
TOSA_REF_TYPE_##DTYPE) \ { \ switch (inputRank) \ { \ @@ -151,7 +159,7 @@ } #define DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \ - if (inputDType == DType_##DTYPE) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE) \ { \ switch (inputRank) \ { \ @@ -165,7 +173,7 @@ } #define DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \ - if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \ { \ switch (inputRank) \ { \ @@ -180,7 +188,7 @@ } #define DEF_FACTORY_RESHAPE(OP, DTYPE) \ - if (inputDType == DType_##DTYPE && outputDType == DType_##DTYPE) \ + if (inputDTYPE == TOSA_REF_TYPE_##DTYPE && outputDTYPE == TOSA_REF_TYPE_##DTYPE) \ { \ switch (inputRank) \ { \ @@ -292,11 +300,11 @@ public: tosa::Op opType, tosa::TosaAttributeBase* attribute, uint64_t id, - tosa::DType inputDType, + TOSA_REF_TYPE inputDTYPE, int inputRank, - tosa::DType outputDType, + TOSA_REF_TYPE outputDTYPE, int outputRank, - tosa::DType weightDType, + TOSA_REF_TYPE weightDTYPE, int weightRank); }; }; // namespace TosaReference diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc index cd9d55f..bf8ba57 100644 --- a/reference_model/src/ops/reduction.cc +++ b/reference_model/src/ops/reduction.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -20,7 +20,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> ReduceNode<Rank, Dtype>::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, op_, id_) { @@ -30,14 +30,14 @@ ReduceNode<Rank, Dtype>::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, Tosa INIT_ATTRIBUTE(Axis); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> ReduceNode<Rank, Dtype>::~ReduceNode() { if (attribute) delete attribute; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int ReduceNode<Rank, Dtype>::checkTensorAttributes() { if (validateRequiredOperands()) @@ -100,7 +100,7 @@ struct AnyReducer { bool finalize(const bool accum) const { return accum; } }; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpReduceAll<Rank, Dtype>::eval() { this->out->getTensor() = this->in->getTensor().reduce(this->dims, AllReducer()).reshape(this->out->getTensor().dimensions()); @@ -108,7 +108,7 @@ int OpReduceAll<Rank, Dtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpReduceAny<Rank, Dtype>::eval() { this->out->getTensor() = this->in->getTensor().reduce(this->dims, AnyReducer()).reshape(this->out->getTensor().dimensions()); @@ -116,7 +116,7 @@ int OpReduceAny<Rank, Dtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpReduceMax<Rank, Dtype>::eval() { this->out->getTensor() = this->in->getTensor().maximum(this->dims).reshape(this->out->getTensor().dimensions()); @@ -124,7 +124,7 @@ int OpReduceMax<Rank, Dtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpReduceMin<Rank, Dtype>::eval() { this->out->getTensor() = 
this->in->getTensor().minimum(this->dims).reshape(this->out->getTensor().dimensions()); @@ -132,35 +132,74 @@ int OpReduceMin<Rank, Dtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpReduceProduct<Rank, Dtype>::eval() { switch(Dtype) { - case DType_FP16: - case DType_BF16: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);}); break; - default: + case TOSA_REF_TYPE_FP32: this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()); break; + default: + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); + } + + return GraphNode::eval(); +} + +struct ProductDoubleReducer +{ + static const bool PacketAccess = false; + void reduce(const double val, double* accum) + { + *accum *= val; + } + double initialize() const + { + return 1.0; + } + double finalize(const double accum) const + { + return accum; + } +}; + +template <int Rank, TOSA_REF_TYPE Dtype> +int OpReduceProductDouble<Rank, Dtype>::eval() +{ + switch (Dtype) + { + case TOSA_REF_TYPE_FP64: + this->out->getTensor() = this->in->getTensor() + .reduce(this->dims, ProductDoubleReducer()) + .reshape(this->out->getTensor().dimensions()); + break; + default: + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpReduceSum<Rank, Dtype>::eval() { switch(Dtype) { - case DType_FP16: - case DType_BF16: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);}); break; - default: + case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_INT32: 
this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()); break; + default: + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); } return GraphNode::eval(); @@ -183,7 +222,7 @@ struct SumRequiresReducer { SubgraphTraverser* parent_sgt; }; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpReduceSumInt<Rank, Dtype>::eval() { this->out->getTensor() = this->in->getTensor().reduce(this->dims, SumRequiresReducer(this->parent_sgt)).reshape(this->out->getTensor().dimensions()); @@ -191,6 +230,40 @@ int OpReduceSumInt<Rank, Dtype>::eval() return GraphNode::eval(); } +struct SumDoubleReducer +{ + static const bool PacketAccess = false; + void reduce(const double val, double* accum) + { + *accum += val; + } + double initialize() const + { + return 0.0; + } + double finalize(const double accum) const + { + return accum; + } +}; + +template <int Rank, TOSA_REF_TYPE Dtype> +int OpReduceSumDouble<Rank, Dtype>::eval() +{ + switch (Dtype) + { + case TOSA_REF_TYPE_FP64: + this->out->getTensor() = this->in->getTensor() + .reduce(this->dims, SumDoubleReducer()) + .reshape(this->out->getTensor().dimensions()); + break; + default: + ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype)); + } + + return GraphNode::eval(); +} + // template explicit instantiation DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL); @@ -202,6 +275,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16); @@ -209,12 +283,15 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32); 
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProductDouble, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumDouble, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32); diff --git a/reference_model/src/ops/reduction.h b/reference_model/src/ops/reduction.h index 6e98a76..aeb9f1d 100644 --- a/reference_model/src/ops/reduction.h +++ b/reference_model/src/ops/reduction.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -23,7 +23,7 @@ using namespace tosa; namespace TosaReference { -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class ReduceNode : public GraphNode { public: @@ -44,7 +44,7 @@ protected: TosaAxisAttribute* attribute; }; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpReduceAll : public ReduceNode<Rank, Dtype> { public: @@ -54,7 +54,7 @@ public: virtual int eval(); }; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpReduceAny : public ReduceNode<Rank, Dtype> { public: @@ -64,7 +64,7 @@ public: virtual int eval(); }; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpReduceMax : public ReduceNode<Rank, Dtype> { public: @@ -74,7 +74,7 @@ public: virtual int eval(); }; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpReduceMin : public ReduceNode<Rank, Dtype> { public: @@ -84,7 +84,7 @@ public: virtual int eval(); }; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpReduceProduct : public ReduceNode<Rank, Dtype> { public: @@ -94,7 +94,17 @@ public: virtual int eval(); }; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> +class OpReduceProductDouble : public ReduceNode<Rank, Dtype> +{ +public: + OpReduceProductDouble(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) + : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_PRODUCT, attribute_, id_) + {} + virtual int eval(); +}; + +template <int Rank, TOSA_REF_TYPE Dtype> class OpReduceSum : public ReduceNode<Rank, Dtype> { public: @@ -104,7 +114,7 @@ public: virtual int eval(); }; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpReduceSumInt : public ReduceNode<Rank, Dtype> { public: @@ -114,6 +124,16 @@ public: virtual int eval(); }; +template <int Rank, TOSA_REF_TYPE Dtype> +class OpReduceSumDouble : public ReduceNode<Rank, Dtype> +{ +public: + 
OpReduceSumDouble(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) + : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_SUM, attribute_, id_) + {} + virtual int eval(); +}; + }; // namespace TosaReference #endif diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc index bcd8ce5..80b6c58 100644 --- a/reference_model/src/ops/scatter_gather.cc +++ b/reference_model/src/ops/scatter_gather.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> OpGather<Dtype>::OpGather(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -29,11 +29,11 @@ OpGather<Dtype>::OpGather(SubgraphTraverser* sgt_, setRequiredOperands(2, 1); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> OpGather<Dtype>::~OpGather() {} -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpGather<Dtype>::checkTensorAttributes() { if (validateRequiredOperands()) @@ -96,7 +96,7 @@ int OpGather<Dtype>::checkTensorAttributes() return 0; } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpGather<Dtype>::eval() { for (int32_t n = 0; n < N; n++) @@ -116,7 +116,7 @@ int OpGather<Dtype>::eval() return GraphNode::eval(); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> OpScatter<Dtype>::OpScatter(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -125,11 +125,11 @@ OpScatter<Dtype>::OpScatter(SubgraphTraverser* sgt_, setRequiredOperands(3, 1); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> OpScatter<Dtype>::~OpScatter() {} -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpScatter<Dtype>::checkTensorAttributes() { if 
(validateRequiredOperands()) @@ -199,7 +199,7 @@ int OpScatter<Dtype>::checkTensorAttributes() return 0; } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpScatter<Dtype>::eval() { // Initializes the output tensor with the input value for values that are unchanged by the scatter operation. @@ -229,6 +229,7 @@ DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32); DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16); DEF_INSTANTIATE_ONE_TYPE(OpGather, BF16); DEF_INSTANTIATE_ONE_TYPE(OpGather, FP32); +DEF_INSTANTIATE_ONE_TYPE(OpGather, FP64); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8); DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16); @@ -236,3 +237,4 @@ DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32); DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, BF16); DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP32); +DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP64); diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h index af09153..fb675a9 100644 --- a/reference_model/src/ops/scatter_gather.h +++ b/reference_model/src/ops/scatter_gather.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -23,7 +23,7 @@ using namespace tosa; namespace TosaReference { -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> class OpGather : public GraphNode { public: @@ -45,7 +45,7 @@ protected: TosaReference::TensorTemplate<TOutput>* output; }; -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> class OpScatter : public GraphNode { public: diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h index ece14b1..6dd6e76 100644 --- a/reference_model/src/ops/template_types.h +++ b/reference_model/src/ops/template_types.h @@ -16,11 +16,10 @@ #ifndef OP_TEMPLATE_TYPES_H #define OP_TEMPLATE_TYPES_H -#include "tosa_generated.h" -#include <Eigen/CXX11/Tensor> +#include "dtype.h" #include "half.hpp" +#include <Eigen/CXX11/Tensor> #include <Eigen/Core> -#include "arith_util.h" using namespace tosa; @@ -64,213 +63,218 @@ using Tensor5 = TensorTemplate<ETensor5<T>>; template <typename T> using Tensor6 = TensorTemplate<ETensor6<T>>; -template <DType type> +template <TOSA_REF_TYPE type> struct GetEigenType; template <> -struct GetEigenType<DType_FP32> +struct GetEigenType<TOSA_REF_TYPE_FP64> +{ + using type = double; +}; +template <> +struct GetEigenType<TOSA_REF_TYPE_FP32> { using type = float; }; template <> -struct GetEigenType<DType_FP16> +struct GetEigenType<TOSA_REF_TYPE_FP16> { // NOTE: full precision used using type = float; }; template <> -struct GetEigenType<DType_BF16> +struct GetEigenType<TOSA_REF_TYPE_BF16> { // NOTE: full precision used using type = float; }; template <> -struct GetEigenType<DType_INT32> +struct GetEigenType<TOSA_REF_TYPE_INT32> { using type = int32_t; }; template <> -struct GetEigenType<DType_INT48> +struct GetEigenType<TOSA_REF_TYPE_INT48> { using type = int64_t; }; template <> -struct GetEigenType<DType_BOOL> +struct GetEigenType<TOSA_REF_TYPE_BOOL> { using type = bool; }; template <> -struct GetEigenType<DType_UINT8> +struct GetEigenType<TOSA_REF_TYPE_UINT8> { using type = int32_t; }; template <> 
-struct GetEigenType<DType_UINT16> +struct GetEigenType<TOSA_REF_TYPE_UINT16> { using type = int32_t; }; template <> -struct GetEigenType<DType_INT4> +struct GetEigenType<TOSA_REF_TYPE_INT4> { using type = int32_t; }; template <> -struct GetEigenType<DType_INT8> +struct GetEigenType<TOSA_REF_TYPE_INT8> { using type = int32_t; }; template <> -struct GetEigenType<DType_INT16> +struct GetEigenType<TOSA_REF_TYPE_INT16> { using type = int32_t; }; /* Get Accumulate Eigen Type: -Same behaviour as GetEigenType for all DTypes except the -single specialised case of DType_FP16. */ -template <DType Dtype> +Same behaviour as GetEigenType for all DTYPEs except the +single specialised case of TOSA_REF_TYPE_FP16. */ +template <TOSA_REF_TYPE Dtype> struct GetAccEigenType; template <> -struct GetAccEigenType<DType_FP16> +struct GetAccEigenType<TOSA_REF_TYPE_FP16> { using type = half_float::half; }; -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> struct GetAccEigenType { using type = typename GetEigenType<Dtype>::type; }; // Meta function to get number of bits -template <DType T> +template <TOSA_REF_TYPE T> struct GetNumBits { static constexpr int32_t value = 0; }; template <> -struct GetNumBits<DType_BOOL> +struct GetNumBits<TOSA_REF_TYPE_BOOL> { static constexpr int32_t value = 1; }; template <> -struct GetNumBits<DType_UINT8> +struct GetNumBits<TOSA_REF_TYPE_UINT8> { static constexpr int32_t value = 8; }; template <> -struct GetNumBits<DType_UINT16> +struct GetNumBits<TOSA_REF_TYPE_UINT16> { static constexpr int32_t value = 16; }; template <> -struct GetNumBits<DType_INT4> +struct GetNumBits<TOSA_REF_TYPE_INT4> { static constexpr int32_t value = 4; }; template <> -struct GetNumBits<DType_INT8> +struct GetNumBits<TOSA_REF_TYPE_INT8> { static constexpr int32_t value = 8; }; template <> -struct GetNumBits<DType_INT16> +struct GetNumBits<TOSA_REF_TYPE_INT16> { static constexpr int32_t value = 16; }; template <> -struct GetNumBits<DType_INT32> +struct 
GetNumBits<TOSA_REF_TYPE_INT32> { static constexpr int32_t value = 32; }; template <> -struct GetNumBits<DType_INT48> +struct GetNumBits<TOSA_REF_TYPE_INT48> { static constexpr int32_t value = 48; }; template <> -struct GetNumBits<DType_FP16> +struct GetNumBits<TOSA_REF_TYPE_FP16> { static constexpr int32_t value = 16; }; // Meta function to get quantized min/max in compile time -template <DType T> +template <TOSA_REF_TYPE T> struct GetQMin { static constexpr int64_t value = INT64_C(0); }; template <> -struct GetQMin<DType_UINT8> +struct GetQMin<TOSA_REF_TYPE_UINT8> { static constexpr int64_t value = INT64_C(0); }; template <> -struct GetQMin<DType_UINT16> +struct GetQMin<TOSA_REF_TYPE_UINT16> { static constexpr int64_t value = INT64_C(0); }; template <> -struct GetQMin<DType_INT4> +struct GetQMin<TOSA_REF_TYPE_INT4> { static constexpr int64_t value = INT64_C(-8); }; template <> -struct GetQMin<DType_INT8> +struct GetQMin<TOSA_REF_TYPE_INT8> { static constexpr int64_t value = INT64_C(-128); }; template <> -struct GetQMin<DType_INT16> +struct GetQMin<TOSA_REF_TYPE_INT16> { static constexpr int64_t value = INT64_C(-32768); }; template <> -struct GetQMin<DType_INT32> +struct GetQMin<TOSA_REF_TYPE_INT32> { static constexpr int64_t value = -(INT64_C(1) << 31); }; template <> -struct GetQMin<DType_INT48> +struct GetQMin<TOSA_REF_TYPE_INT48> { static constexpr int64_t value = -(INT64_C(1) << 47); }; -template <DType T> +template <TOSA_REF_TYPE T> struct GetQMax { static constexpr int64_t value = INT64_C(0); }; template <> -struct GetQMax<DType_UINT8> +struct GetQMax<TOSA_REF_TYPE_UINT8> { static constexpr int64_t value = INT64_C(255); }; template <> -struct GetQMax<DType_UINT16> +struct GetQMax<TOSA_REF_TYPE_UINT16> { static constexpr int64_t value = INT64_C(65535); }; template <> -struct GetQMax<DType_INT4> +struct GetQMax<TOSA_REF_TYPE_INT4> { static constexpr int64_t value = INT64_C(7); }; template <> -struct GetQMax<DType_INT8> +struct GetQMax<TOSA_REF_TYPE_INT8> { 
static constexpr int64_t value = INT64_C(127); }; template <> -struct GetQMax<DType_INT16> +struct GetQMax<TOSA_REF_TYPE_INT16> { static constexpr int64_t value = INT64_C(32767); }; template <> -struct GetQMax<DType_INT32> +struct GetQMax<TOSA_REF_TYPE_INT32> { static constexpr int64_t value = (INT64_C(1) << 31) - 1; }; template <> -struct GetQMax<DType_INT48> +struct GetQMax<TOSA_REF_TYPE_INT48> { static constexpr int64_t value = (INT64_C(1) << 47) - 1; }; diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index b3845df..f8fd323 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -116,14 +116,14 @@ int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute, } int check_conv_attribute(tosa::TosaConvAttribute* attribute, - uint32_t conv_dimension, - std::vector<int32_t> input_shape, - std::vector<int32_t> output_shape, - std::vector<int32_t> weights, - uint32_t offset_kernel, - DType InDtype, - DType WeightDtype, - std::string& msg) + uint32_t conv_dimension, + std::vector<int32_t> input_shape, + std::vector<int32_t> output_shape, + std::vector<int32_t> weights, + uint32_t offset_kernel, + TOSA_REF_TYPE InDtype, + TOSA_REF_TYPE WeightDtype, + std::string& msg) { if (attribute->pad().size() != (2 * conv_dimension)) { @@ -226,11 +226,13 @@ int check_conv_attribute(tosa::TosaConvAttribute* attribute, return 1; } - if (InDtype != DType_INT8 && attribute->input_zp() != 0) { + if (InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0) + { msg = "Input zero point must be zero for non-int8 data"; return 1; } - if (WeightDtype != DType_INT8 && attribute->weight_zp() != 0) { + if (WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0) + { msg = "Weight zero point must be zero for non-int8 data"; return 1; } @@ -318,7 +320,7 @@ int check_fft_shape(const std::vector<int32_t>& in_real, return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> 
OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -330,14 +332,14 @@ OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Axis); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpArgMax<Rank, Dtype>::~OpArgMax() { if (attribute) delete attribute; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpArgMax<Rank, Dtype>::checkTensorAttributes() { if (validateRequiredOperands()) @@ -355,7 +357,7 @@ int OpArgMax<Rank, Dtype>::checkTensorAttributes() return 1; } - if (outputs[0]->getDtype() != DType_INT32) + if (outputs[0]->getDtype() != TOSA_REF_TYPE_INT32) { printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator"); return 1; @@ -400,7 +402,7 @@ int OpArgMax<Rank, Dtype>::checkTensorAttributes() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpArgMax<Rank, Dtype>::eval() { Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis()); @@ -410,7 +412,7 @@ int OpArgMax<Rank, Dtype>::eval() return GraphNode::eval(); } -template <DType Dtype, DType AccDtype> +template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype> OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -422,14 +424,14 @@ OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Pool); } -template <DType Dtype, DType AccDtype> +template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype> OpAvgPool2d<Dtype, AccDtype>::~OpAvgPool2d() { if (attribute) delete attribute; } -template <DType Dtype, DType AccDtype> +template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype> int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes() { if (validateRequiredOperands()) @@ -449,8 +451,10 @@ int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes() in = 
dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); - ERROR_IF(Dtype != DType_INT8 && attribute->input_zp() != 0, "OpAvgPool2d: Input zeropoint must be zero for non int8_t data"); - ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpAvgPool2d: Output zeropoint must be zero for non int8_t data"); + ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0, + "OpAvgPool2d: Input zeropoint must be zero for non int8_t data"); + ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->output_zp() != 0, + "OpAvgPool2d: Output zeropoint must be zero for non int8_t data"); std::string msg; if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg)) @@ -466,8 +470,9 @@ int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes() // This calculates the number of padding elements used for each location along an axis // Average pooling only divides by the number of elements used, not including padding. 
// This function uses left/right, but is also used for vertical padding with top/bottom -template <DType Dtype, DType AccDtype> -ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right) +template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype> +ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d( + int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right) { ETensor1<int32_t> result(out_size); @@ -495,7 +500,7 @@ ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(int in_size // assuming input and output tensor have same scales like tflite reference // so no need to scale input and output -template <DType Dtype, DType AccDtype> +template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype> int OpAvgPool2d<Dtype, AccDtype>::eval() { int in_batch = this->in->getShape()[0]; @@ -531,7 +536,7 @@ int OpAvgPool2d<Dtype, AccDtype>::eval() LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL"); LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL"); - tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype(); + TOSA_REF_TYPE accum_dtype = ConvertDType(this->attribute->accum_dtype()); DEBUG_INFO(OP, "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], " @@ -556,7 +561,7 @@ int OpAvgPool2d<Dtype, AccDtype>::eval() pad[3] = std::make_pair(0, 0); ETensor4<InEigenType> input_val = this->in->getTensor(); - if (Dtype == DType_INT8) + if (Dtype == TOSA_REF_TYPE_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); } @@ -604,7 +609,8 @@ int OpAvgPool2d<Dtype, AccDtype>::eval() dm2_h.contract(dm2_w, contract_dims) .reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 }) .broadcast(bcast); - if (Dtype != DType_FP32 && Dtype != DType_FP16 && Dtype 
!= DType_BF16) + if (Dtype != TOSA_REF_TYPE_FP32 && Dtype != TOSA_REF_TYPE_FP16 && Dtype != TOSA_REF_TYPE_BF16 && + Dtype != TOSA_REF_TYPE_FP64) { try { @@ -632,7 +638,7 @@ int OpAvgPool2d<Dtype, AccDtype>::eval() return GraphNode::eval(); } -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -644,14 +650,14 @@ OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Conv); } -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> OpConv2d<InDtype, WeightDtype, OutDtype>::~OpConv2d() { if (attribute) delete attribute; } -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() { if (validateRequiredOperands()) @@ -688,7 +694,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() return 0; } -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> int OpConv2d<InDtype, WeightDtype, OutDtype>::eval() { int in_batch = this->input->getShape()[0]; @@ -781,7 +787,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (InDtype == DType_INT8 || WeightDtype == DType_INT8) + if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); @@ -817,7 +823,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval() // reshape back to [N, H, W, C] 
this->output->getTensor() = biased_output.reshape(col2im_output_dims); - if (OutDtype == DType_INT48) + if (OutDtype == TOSA_REF_TYPE_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -826,7 +832,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval() return GraphNode::eval(); } -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -838,14 +844,14 @@ OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Conv); } -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> OpConv3d<InDtype, WeightDtype, OutDtype>::~OpConv3d() { if (attribute) delete attribute; } -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() { if (validateRequiredOperands()) @@ -882,7 +888,7 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() return 0; } -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> int OpConv3d<InDtype, WeightDtype, OutDtype>::eval() { int in_batch = this->input->getShape()[0]; @@ -959,7 +965,7 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (InDtype == DType_INT8 || WeightDtype == DType_INT8) + if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); 
weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); @@ -1020,7 +1026,7 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval() } } - if (OutDtype == DType_INT48) + if (OutDtype == TOSA_REF_TYPE_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -1029,10 +1035,10 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval() return GraphNode::eval(); } -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_, - TosaAttributeBase* attribute_, - uint64_t id_) + TosaAttributeBase* attribute_, + uint64_t id_) : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_) { setRequiredOperands(3, 1); @@ -1041,14 +1047,14 @@ OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTra INIT_ATTRIBUTE(Conv); } -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::~OpDepthwiseConv2d() { if (attribute) delete attribute; } -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() { if (validateRequiredOperands()) @@ -1085,7 +1091,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() return 0; } -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval() { int in_batch = this->input->getShape()[0]; @@ -1149,7 +1155,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, 
OutDtype>::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (InDtype == DType_INT8 || WeightDtype == DType_INT8) + if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); @@ -1205,7 +1211,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval() } } - if (OutDtype == DType_INT48) + if (OutDtype == TOSA_REF_TYPE_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -1214,10 +1220,10 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval() return GraphNode::eval(); } -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTraverser* sgt_, - TosaAttributeBase* attribute_, - uint64_t id_) + TosaAttributeBase* attribute_, + uint64_t id_) : GraphNode(sgt_, Op_FULLY_CONNECTED, id_) { setRequiredOperands(3, 1); @@ -1226,14 +1232,14 @@ OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTrave INIT_ATTRIBUTE(FullyConnected); } -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> OpFullyConnected<InDtype, WeightDtype, OutDtype>::~OpFullyConnected() { if (attribute) delete attribute; } -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() { if (validateRequiredOperands()) @@ -1265,13 +1271,15 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() output = 
dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); - ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpFullyConnected: Input zeropoint must be zero for non int8_t data"); - ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpFullyConnected: Weight zeropoint must be zero for non int8_t data"); + ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0, + "OpFullyConnected: Input zeropoint must be zero for non int8_t data"); + ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0, + "OpFullyConnected: Weight zeropoint must be zero for non int8_t data"); return 0; } -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval() { typedef Eigen::Tensor<int, 1>::DimensionPair DimPair; @@ -1289,7 +1297,7 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle); - if (InDtype == DType_INT8 || WeightDtype == DType_INT8) + if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); @@ -1299,7 +1307,7 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval() input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims).template cast<OutEigenType>() + this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast); - if (OutDtype == DType_INT48) + if (OutDtype == TOSA_REF_TYPE_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -1307,7 +1315,7 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval() return 
GraphNode::eval(); } -template <DType Dtype, DType OutDtype> +template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype> OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -1319,14 +1327,14 @@ OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(MatMul); } -template <DType Dtype, DType OutDtype> +template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype> OpMatMul<Dtype, OutDtype>::~OpMatMul() { if (attribute) delete attribute; } -template <DType Dtype, DType OutDtype> +template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype> int OpMatMul<Dtype, OutDtype>::checkTensorAttributes() { if (validateRequiredOperands()) @@ -1382,13 +1390,15 @@ int OpMatMul<Dtype, OutDtype>::checkTensorAttributes() } W = b->getShape()[2]; - ERROR_IF(Dtype != DType_INT8 && attribute->a_zp() != 0, "OpMatMul: A zeropoint must be zero for non int8_t data"); - ERROR_IF(Dtype != DType_INT8 && attribute->b_zp() != 0, "OpMatMul: B zeropoint must be zero for non int8_t data"); + ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->a_zp() != 0, + "OpMatMul: A zeropoint must be zero for non int8_t data"); + ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->b_zp() != 0, + "OpMatMul: B zeropoint must be zero for non int8_t data"); return 0; } -template <DType Dtype, DType OutDtype> +template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype> int OpMatMul<Dtype, OutDtype>::eval() { typedef Eigen::Tensor<int, 1>::DimensionPair DimPair; @@ -1396,7 +1406,7 @@ int OpMatMul<Dtype, OutDtype>::eval() TIn a_val = this->a->getTensor(); TIn b_val = this->b->getTensor(); - if (Dtype == DType_INT8) + if (Dtype == TOSA_REF_TYPE_INT8) { a_val = a_val - (InEigenType)attribute->a_zp(); b_val = b_val - (InEigenType)attribute->b_zp(); @@ -1434,7 +1444,7 @@ int OpMatMul<Dtype, OutDtype>::eval() } } - if (OutDtype == DType_INT48) + if (OutDtype == TOSA_REF_TYPE_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); 
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -1443,7 +1453,7 @@ int OpMatMul<Dtype, OutDtype>::eval() return GraphNode::eval(); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -1455,14 +1465,14 @@ OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Pool); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> OpMaxPool2d<Dtype>::~OpMaxPool2d() { if (attribute) delete attribute; } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpMaxPool2d<Dtype>::checkTensorAttributes() { if (validateRequiredOperands()) @@ -1493,7 +1503,7 @@ int OpMaxPool2d<Dtype>::checkTensorAttributes() return 0; } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpMaxPool2d<Dtype>::eval() { int in_batch = this->in->getShape()[0]; @@ -1586,10 +1596,8 @@ int OpMaxPool2d<Dtype>::eval() return GraphNode::eval(); } -template <DType Dtype> -OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_, - TosaAttributeBase* attribute_, - uint64_t id_) +template <TOSA_REF_TYPE Dtype> +OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_FFT2D, id_) { setRequiredOperands(2, 2); @@ -1598,14 +1606,14 @@ OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(FFT); } -template <DType Dtype> -OpFFT2d<Dtype>::~OpFFT2d() { +template <TOSA_REF_TYPE Dtype> +OpFFT2d<Dtype>::~OpFFT2d() +{ if (attribute) delete attribute; } - -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpFFT2d<Dtype>::checkTensorAttributes() { if (validateRequiredOperands()) @@ -1643,7 +1651,7 @@ int OpFFT2d<Dtype>::checkTensorAttributes() return 0; } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpFFT2d<Dtype>::eval() { int in_real_batch = this->in_real->getShape()[0]; @@ -1709,7 +1717,7 @@ int OpFFT2d<Dtype>::eval() return GraphNode::eval(); } -template <DType 
Dtype> +template <TOSA_REF_TYPE Dtype> OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -1719,11 +1727,11 @@ OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_, setRequiredRank(3); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> OpRFFT2d<Dtype>::~OpRFFT2d() {} -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpRFFT2d<Dtype>::checkTensorAttributes() { if (validateRequiredOperands()) @@ -1759,7 +1767,7 @@ int OpRFFT2d<Dtype>::checkTensorAttributes() return 0; } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpRFFT2d<Dtype>::eval() { int32_t in_batch = in->getShape()[0]; @@ -1815,10 +1823,10 @@ int OpRFFT2d<Dtype>::eval() return GraphNode::eval(); } -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_, - TosaAttributeBase* attribute_, - uint64_t id_) + TosaAttributeBase* attribute_, + uint64_t id_) : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_) { setRequiredOperands(3, 1); @@ -1827,14 +1835,14 @@ OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTra INIT_ATTRIBUTE(TransposeConv); } -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::~OpTransposeConv2d() { if (attribute) delete attribute; } -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() { if (validateRequiredOperands()) @@ -1923,13 +1931,15 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() return 1; } - ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpTransposeConv2d: 
Input zeropoint must be zero for non int8_t data"); - ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data"); + ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0, + "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data"); + ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0, + "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data"); return 0; } -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval() { int in_batch = this->input->getShape()[0]; @@ -1985,7 +1995,7 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval() TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); - if (InDtype == DType_INT8 || WeightDtype == DType_INT8) + if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8) { input_val = input_val - (InEigenType)attribute->input_zp(); weight_val = weight_val - (WeightEigenType)attribute->weight_zp(); @@ -2040,7 +2050,7 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval() } } - if (OutDtype == DType_INT48) + if (OutDtype == TOSA_REF_TYPE_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -2055,6 +2065,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32); @@ -2062,6 +2073,7 @@ 
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, BF16, FP32); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32); +DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64); // [in_t, weight_t, out_t] DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16); @@ -2071,6 +2083,7 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32); DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48); +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32); @@ -2079,6 +2092,7 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48); +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32); @@ -2087,8 +2101,10 @@ DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64); DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32); +DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64); DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16); DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32); @@ -2097,6 +2113,7 @@ DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32); DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32); 
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48); +DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP64, FP64, FP64); DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT8, INT32); DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT16, INT48); @@ -2104,14 +2121,17 @@ DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP16); DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32); DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32); DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32); +DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP64, FP64); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16); +DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP64); DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32); +DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64); DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16); DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32); @@ -2120,3 +2140,4 @@ DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32); DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32); DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32); DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64); diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index 9ef4a58..df53f2b 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -24,7 +24,7 @@ using namespace tosa; namespace TosaReference { -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> class OpArgMax : public GraphNode { public: @@ -35,7 +35,7 @@ public: virtual int eval(); using InEigenType = typename GetEigenType<Dtype>::type; - using OutEigenType = typename GetEigenType<DType_INT32>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_INT32>::type; using TIn = 
Eigen::Tensor<InEigenType, Rank>; using TOut = Eigen::Tensor<OutEigenType, Rank - 1>; @@ -45,7 +45,7 @@ protected: TosaReference::TensorTemplate<TOut>* output; }; -template <DType Dtype, DType AccDtype> +template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype> class OpAvgPool2d : public GraphNode { public: @@ -74,7 +74,7 @@ protected: ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right); }; -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> class OpConv2d : public GraphNode { public: @@ -104,7 +104,7 @@ protected: tosa::TosaConvAttribute* attribute; }; -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> class OpConv3d : public GraphNode { public: @@ -134,7 +134,7 @@ protected: tosa::TosaConvAttribute* attribute; }; -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> class OpDepthwiseConv2d : public GraphNode { public: @@ -164,7 +164,7 @@ protected: tosa::TosaConvAttribute* attribute; }; -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> class OpFullyConnected : public GraphNode { public: @@ -195,7 +195,7 @@ protected: tosa::TosaFullyConnectedAttribute* attribute; }; -template <DType Dtype, DType OutDtype> +template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype> class OpMatMul : public GraphNode { public: @@ -227,7 +227,7 @@ protected: tosa::TosaMatMulAttribute* attribute; }; -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> class OpMaxPool2d : public GraphNode { public: @@ -248,7 +248,7 @@ protected: tosa::TosaPoolAttribute* attribute; }; -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> class 
OpFFT2d : public GraphNode { public: @@ -271,7 +271,7 @@ protected: tosa::TosaFFTAttribute* attribute; }; -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> class OpRFFT2d : public GraphNode { public: @@ -292,7 +292,7 @@ protected: TosaReference::TensorTemplate<TOut>* out_imag; }; -template <DType InDtype, DType WeightDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> class OpTransposeConv2d : public GraphNode { public: diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc index 9034add..68ffb1f 100644 --- a/reference_model/src/ops/type_conversion.cc +++ b/reference_model/src/ops/type_conversion.cc @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -35,14 +35,14 @@ OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Rescale); } -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> OpRescale<Rank, InDtype, OutDtype>::~OpRescale() { if (attribute) delete attribute; } -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes() { // Check Tosa Level @@ -69,31 +69,33 @@ int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes() ASSERT_MEM(in && out); - if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (InDtype != 
DType_UINT16) && (attribute->input_zp() != 0)) + if ((InDtype != TOSA_REF_TYPE_INT8) && (InDtype != TOSA_REF_TYPE_UINT8) && (InDtype != TOSA_REF_TYPE_UINT16) && + (attribute->input_zp() != 0)) { - printNodeValidationError("OpRescale: Input DType not INT8/UINT8/UINT16 and zero point not 0"); + printNodeValidationError("OpRescale: Input TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0"); return 1; } - if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (OutDtype != DType_UINT16) && (attribute->output_zp() != 0)) + if ((OutDtype != TOSA_REF_TYPE_INT8) && (OutDtype != TOSA_REF_TYPE_UINT8) && (OutDtype != TOSA_REF_TYPE_UINT16) && + (attribute->output_zp() != 0)) { - printNodeValidationError("OpRescale: Output DType not INT8/UINT8/UINT16 and zero point not 0"); + printNodeValidationError("OpRescale: Output TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0"); return 1; } - if ((InDtype == DType_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768))) + if ((InDtype == TOSA_REF_TYPE_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768))) { - printNodeValidationError("OpRescale: Input DType UINT16 and zero point not 0 or 32768"); + printNodeValidationError("OpRescale: Input TOSA_REF_TYPE UINT16 and zero point not 0 or 32768"); return 1; } - if ((OutDtype == DType_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768))) + if ((OutDtype == TOSA_REF_TYPE_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768))) { - printNodeValidationError("OpRescale: Output DType UINT16 and zero point not 0 or 32768"); + printNodeValidationError("OpRescale: Output TOSA_REF_TYPE UINT16 and zero point not 0 or 32768"); return 1; } - if (attribute->scale32() && (InDtype == DType_INT48)) + if (attribute->scale32() && (InDtype == TOSA_REF_TYPE_INT48)) { printNodeValidationError("OpRescale: Scale set to true but input type is INT48"); return 1; @@ -108,7 +110,7 @@ int OpRescale<Rank, 
InDtype, OutDtype>::checkTensorAttributes() return 0; } -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> int OpRescale<Rank, InDtype, OutDtype>::eval() { int32_t input_zp = attribute->input_zp(); @@ -237,7 +239,7 @@ int OpRescale<Rank, InDtype, OutDtype>::eval() return GraphNode::eval(); } -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -247,11 +249,11 @@ OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_, setRequiredRank(0, 6); } -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> OpCast<Rank, InDtype, OutDtype>::~OpCast() {} -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes() { // Check Tosa Level @@ -281,7 +283,7 @@ int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes() return 0; } -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> int OpCast<Rank, InDtype, OutDtype>::eval() { this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn()); @@ -289,7 +291,7 @@ int OpCast<Rank, InDtype, OutDtype>::eval() return GraphNode::eval(); } -template <DType InDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> CastHelper<InDtype, OutDtype>::CastHelper() { fcn = [](InEigenType in) -> OutEigenType { @@ -298,14 +300,14 @@ CastHelper<InDtype, OutDtype>::CastHelper() }; } -template <DType InDtype> -CastHelper<InDtype, DType_BOOL>::CastHelper() +template <TOSA_REF_TYPE InDtype> +CastHelper<InDtype, TOSA_REF_TYPE_BOOL>::CastHelper() { fcn = [](InEigenType in) -> bool { return (in != 0) ? 
true : false; }; } -template <DType OutDtype> -CastHelper<DType_BOOL, OutDtype>::CastHelper() +template <TOSA_REF_TYPE OutDtype> +CastHelper<TOSA_REF_TYPE_BOOL, OutDtype>::CastHelper() { fcn = [](bool in) -> OutEigenType { OutEigenType out = in ? (OutEigenType)1 : (OutEigenType)0; @@ -313,8 +315,8 @@ CastHelper<DType_BOOL, OutDtype>::CastHelper() }; } -template <DType InDtype> -CastHelper<InDtype, DType_FP16>::CastHelper() +template <TOSA_REF_TYPE InDtype> +CastHelper<InDtype, TOSA_REF_TYPE_FP16>::CastHelper() { // Integer data converted to fp16 (stored as fp32) fcn = [](InEigenType in) -> float { @@ -324,17 +326,17 @@ CastHelper<InDtype, DType_FP16>::CastHelper() }; } -CastHelper<DType_FP32, DType_FP16>::CastHelper() +CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP16>::CastHelper() { // fp32 data converted to fp16 (stored as fp32) fcn = [](float in) -> float { - float out = fpTrunc<DType_FP16>(in); // truncate required for conversion from higher precision + float out = fpTrunc<TOSA_REF_TYPE_FP16>(in); // truncate required for conversion from higher precision return out; }; } -template <DType InDtype> -CastHelper<InDtype, DType_BF16>::CastHelper() +template <TOSA_REF_TYPE InDtype> +CastHelper<InDtype, TOSA_REF_TYPE_BF16>::CastHelper() { // Integer data converted to bf16 (stored as fp32) fcn = [](InEigenType in) -> float { @@ -343,16 +345,16 @@ CastHelper<InDtype, DType_BF16>::CastHelper() }; } -CastHelper<DType_FP32, DType_BF16>::CastHelper() +CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_BF16>::CastHelper() { // fp32 data converted to bf16 (stored as fp32) fcn = [](float in) -> float { - return fpTrunc<DType_BF16>(in); // truncate required for conversions from higher precision + return fpTrunc<TOSA_REF_TYPE_BF16>(in); // truncate required for conversions from higher precision }; } -template <DType OutDtype> -CastHelper<DType_FP16, OutDtype>::CastHelper() +template <TOSA_REF_TYPE OutDtype> +CastHelper<TOSA_REF_TYPE_FP16, OutDtype>::CastHelper() { // fp16 data 
(stored as fp32) converted to integer fcn = [](float in) -> OutEigenType { @@ -366,7 +368,7 @@ CastHelper<DType_FP16, OutDtype>::CastHelper() }; } -CastHelper<DType_FP16, DType_FP32>::CastHelper() +CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP32>::CastHelper() { // No-op since fp16 values treated internally as their fp32 representation fcn = [](float in) -> OutEigenType { @@ -374,8 +376,8 @@ CastHelper<DType_FP16, DType_FP32>::CastHelper() }; } -template <DType OutDtype> -CastHelper<DType_BF16, OutDtype>::CastHelper() +template <TOSA_REF_TYPE OutDtype> +CastHelper<TOSA_REF_TYPE_BF16, OutDtype>::CastHelper() { // bf16 data (stored as fp32) converted to integer fcn = [](float in) -> OutEigenType { @@ -386,7 +388,7 @@ CastHelper<DType_BF16, OutDtype>::CastHelper() }; } -CastHelper<DType_BF16, DType_FP32>::CastHelper() +CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP32>::CastHelper() { // No-op since bf16 values treated as truncated fp32 internally fcn = [](InEigenType in) -> OutEigenType { @@ -394,8 +396,8 @@ CastHelper<DType_BF16, DType_FP32>::CastHelper() }; } -template <DType InDtype> -CastHelper<InDtype, DType_FP32>::CastHelper() +template <TOSA_REF_TYPE InDtype> +CastHelper<InDtype, TOSA_REF_TYPE_FP32>::CastHelper() { // Integer data converted to fp32 fcn = [](InEigenType in) -> float { @@ -404,8 +406,8 @@ CastHelper<InDtype, DType_FP32>::CastHelper() }; } -template <DType OutDtype> -CastHelper<DType_FP32, OutDtype>::CastHelper() +template <TOSA_REF_TYPE OutDtype> +CastHelper<TOSA_REF_TYPE_FP32, OutDtype>::CastHelper() { // fp32 data converted to integer fcn = [](float in) -> OutEigenType { @@ -416,6 +418,31 @@ CastHelper<DType_FP32, OutDtype>::CastHelper() }; } +template <TOSA_REF_TYPE OutDtype> +CastHelper<TOSA_REF_TYPE_FP64, OutDtype>::CastHelper() +{ + switch (OutDtype) + { + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_INT32: + // fp64 data converted to integer + fcn = [](InEigenType in) -> OutEigenType { + OutEigenType out = 
std::rint(in); + out = std::max<OutEigenType>(out, OutMin); + out = std::min<OutEigenType>(out, OutMax); + return out; + }; + break; + case TOSA_REF_TYPE_FP64: + // no op + fcn = [](InEigenType in) -> OutEigenType { return in; }; + break; + default: + ASSERT_MSG(false, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(OutDtype)); + } +} + // template explicit instantiation DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16); @@ -451,6 +478,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8); DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16); diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h index e2fc6e2..98799a0 100644 --- a/reference_model/src/ops/type_conversion.h +++ b/reference_model/src/ops/type_conversion.h @@ -1,5 +1,5 @@ -// Copyright (c) 2020-2022, ARM Limited. +// Copyright (c) 2020-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -22,7 +22,7 @@ using namespace tosa; namespace TosaReference { -template <int Rank, DType InDtype, DType OutDtype> +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> class OpRescale : public GraphNode { public: @@ -46,7 +46,7 @@ protected: TosaReference::TensorTemplate<TOut>* out; }; -template <DType InDtype, DType OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> class CastHelper { public: @@ -64,12 +64,12 @@ private: FcnType fcn; }; -template <DType InDtype> -class CastHelper<InDtype, DType_BOOL> +template <TOSA_REF_TYPE InDtype> +class CastHelper<InDtype, TOSA_REF_TYPE_BOOL> { public: using InEigenType = typename GetEigenType<InDtype>::type; - using OutEigenType = typename GetEigenType<DType_BOOL>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type; using FcnType = std::function<OutEigenType(InEigenType)>; CastHelper(); const FcnType& get_fcn() const @@ -81,11 +81,11 @@ private: FcnType fcn; }; -template <DType OutDtype> -class CastHelper<DType_BOOL, OutDtype> +template <TOSA_REF_TYPE OutDtype> +class CastHelper<TOSA_REF_TYPE_BOOL, OutDtype> { public: - using InEigenType = typename GetEigenType<DType_BOOL>::type; + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type; using OutEigenType = typename GetEigenType<OutDtype>::type; using FcnType = std::function<OutEigenType(InEigenType)>; static constexpr int32_t OutMin = GetQMin<OutDtype>::value; @@ -100,12 +100,12 @@ private: FcnType fcn; }; -template <DType InDtype> -class CastHelper<InDtype, DType_FP16> +template <TOSA_REF_TYPE InDtype> +class CastHelper<InDtype, TOSA_REF_TYPE_FP16> { public: using InEigenType = typename GetEigenType<InDtype>::type; - using OutEigenType = typename GetEigenType<DType_FP16>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type; using FcnType = std::function<OutEigenType(InEigenType)>; CastHelper(); const FcnType& get_fcn() const @@ -117,11 +117,11 @@ private: FcnType fcn; }; 
-template <DType OutDtype> -class CastHelper<DType_FP16, OutDtype> +template <TOSA_REF_TYPE OutDtype> +class CastHelper<TOSA_REF_TYPE_FP16, OutDtype> { public: - using InEigenType = typename GetEigenType<DType_FP16>::type; + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type; using OutEigenType = typename GetEigenType<OutDtype>::type; using FcnType = std::function<OutEigenType(InEigenType)>; static constexpr int32_t OutMin = GetQMin<OutDtype>::value; @@ -137,11 +137,11 @@ private: }; template <> -class CastHelper<DType_FP32, DType_FP16> +class CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP16> { public: - using InEigenType = typename GetEigenType<DType_FP32>::type; - using OutEigenType = typename GetEigenType<DType_FP16>::type; + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type; using FcnType = std::function<OutEigenType(InEigenType)>; CastHelper(); const FcnType& get_fcn() const @@ -153,12 +153,12 @@ private: FcnType fcn; }; -template <DType InDtype> -class CastHelper<InDtype, DType_BF16> +template <TOSA_REF_TYPE InDtype> +class CastHelper<InDtype, TOSA_REF_TYPE_BF16> { public: using InEigenType = typename GetEigenType<InDtype>::type; - using OutEigenType = typename GetEigenType<DType_BF16>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type; using FcnType = std::function<OutEigenType(InEigenType)>; CastHelper(); const FcnType& get_fcn() const @@ -170,11 +170,11 @@ private: FcnType fcn; }; -template <DType OutDtype> -class CastHelper<DType_BF16, OutDtype> +template <TOSA_REF_TYPE OutDtype> +class CastHelper<TOSA_REF_TYPE_BF16, OutDtype> { public: - using InEigenType = typename GetEigenType<DType_BF16>::type; + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type; using OutEigenType = typename GetEigenType<OutDtype>::type; using FcnType = std::function<OutEigenType(InEigenType)>; static constexpr int32_t OutMin = 
GetQMin<OutDtype>::value; @@ -190,11 +190,11 @@ private: }; template <> -class CastHelper<DType_FP32, DType_BF16> +class CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_BF16> { public: - using InEigenType = typename GetEigenType<DType_FP32>::type; - using OutEigenType = typename GetEigenType<DType_BF16>::type; + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type; using FcnType = std::function<OutEigenType(InEigenType)>; CastHelper(); const FcnType& get_fcn() const @@ -206,12 +206,12 @@ private: FcnType fcn; }; -template <DType InDtype> -class CastHelper<InDtype, DType_FP32> +template <TOSA_REF_TYPE InDtype> +class CastHelper<InDtype, TOSA_REF_TYPE_FP32> { public: using InEigenType = typename GetEigenType<InDtype>::type; - using OutEigenType = typename GetEigenType<DType_FP32>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type; using FcnType = std::function<OutEigenType(InEigenType)>; CastHelper(); const FcnType& get_fcn() const @@ -224,11 +224,11 @@ private: }; template <> -class CastHelper<DType_FP16, DType_FP32> +class CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP32> { public: - using InEigenType = typename GetEigenType<DType_FP16>::type; - using OutEigenType = typename GetEigenType<DType_FP32>::type; + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type; + using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type; using FcnType = std::function<OutEigenType(InEigenType)>; CastHelper(); const FcnType& get_fcn() const @@ -241,11 +241,11 @@ private: }; template <> -class CastHelper<DType_BF16, DType_FP32> +class CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP32> { public: - using InEigenType = typename GetEigenType<DType_BF16>::type; - using OutEigenType = typename GetEigenType<DType_FP32>::type; + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type; + using OutEigenType = typename 
GetEigenType<TOSA_REF_TYPE_FP32>::type; using FcnType = std::function<OutEigenType(InEigenType)>; CastHelper(); const FcnType& get_fcn() const @@ -257,11 +257,11 @@ private: FcnType fcn; }; -template <DType OutDtype> -class CastHelper<DType_FP32, OutDtype> +template <TOSA_REF_TYPE OutDtype> +class CastHelper<TOSA_REF_TYPE_FP32, OutDtype> { public: - using InEigenType = typename GetEigenType<DType_FP32>::type; + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type; using OutEigenType = typename GetEigenType<OutDtype>::type; using FcnType = std::function<OutEigenType(InEigenType)>; static constexpr int32_t OutMin = GetQMin<OutDtype>::value; @@ -276,7 +276,26 @@ private: FcnType fcn; }; -template <int Rank, DType InDtype, DType OutDtype> +template <TOSA_REF_TYPE OutDtype> +class CastHelper<TOSA_REF_TYPE_FP64, OutDtype> +{ +public: + using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP64>::type; + using OutEigenType = typename GetEigenType<OutDtype>::type; + using FcnType = std::function<OutEigenType(InEigenType)>; + static constexpr int32_t OutMin = GetQMin<OutDtype>::value; + static constexpr int32_t OutMax = GetQMax<OutDtype>::value; + CastHelper(); + const FcnType& get_fcn() const + { + return fcn; + } + +private: + FcnType fcn; +}; + +template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype> class OpCast : public GraphNode { public: |