aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops
diff options
context:
space:
mode:
author: Tai Ly <tai.ly@arm.com> 2023-03-28 22:06:56 +0000
committer: Tai Ly <tai.ly@arm.com> 2023-05-05 19:23:15 +0000
commit: a4d748b08accce06fab93e2d2b96e499b35ae89b (patch)
tree: 20a3957e1f45f65f35d5d67ecce1618659e388f0 /reference_model/src/ops
parent: 0c71686875618b2e11290273b7a05b88ef8a8aae (diff)
download: reference_model-a4d748b08accce06fab93e2d2b96e499b35ae89b.tar.gz
[reference model] Add precise mode
This adds a --precise_mode=1 option to tosa_reference_model, which causes the reference model to convert all floating-point tensors to FP64 tensors and compute all operators accordingly. It also adds an optional -p argument to the test runners tosa_verif_run_tests.py and tosa_verif_framework_compiler_runner.py to run tests in precise mode. Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I156055216ad61710096497a8fa1a653be2a602a3
Diffstat (limited to 'reference_model/src/ops')
-rw-r--r-- reference_model/src/ops/activation_funcs.cc 56
-rw-r--r-- reference_model/src/ops/activation_funcs.h 8
-rw-r--r-- reference_model/src/ops/comparison.cc 44
-rw-r--r-- reference_model/src/ops/comparison.h 26
-rw-r--r-- reference_model/src/ops/control_flow.cc 21
-rw-r--r-- reference_model/src/ops/data_layout.cc 93
-rw-r--r-- reference_model/src/ops/data_layout.h 22
-rw-r--r-- reference_model/src/ops/data_nodes.cc 11
-rw-r--r-- reference_model/src/ops/data_nodes.h 4
-rw-r--r-- reference_model/src/ops/ewise_binary.cc 210
-rw-r--r-- reference_model/src/ops/ewise_binary.h 36
-rw-r--r-- reference_model/src/ops/ewise_ternary.cc 18
-rw-r--r-- reference_model/src/ops/ewise_ternary.h 10
-rw-r--r-- reference_model/src/ops/ewise_unary.cc 164
-rw-r--r-- reference_model/src/ops/ewise_unary.h 10
-rw-r--r-- reference_model/src/ops/image.cc 55
-rw-r--r-- reference_model/src/ops/image.h 4
-rw-r--r-- reference_model/src/ops/op_factory.cc 60
-rw-r--r-- reference_model/src/ops/op_factory.h 82
-rw-r--r-- reference_model/src/ops/reduction.cc 111
-rw-r--r-- reference_model/src/ops/reduction.h 38
-rw-r--r-- reference_model/src/ops/scatter_gather.cc 20
-rw-r--r-- reference_model/src/ops/scatter_gather.h 6
-rw-r--r-- reference_model/src/ops/template_types.h 96
-rw-r--r-- reference_model/src/ops/tensor_ops.cc 203
-rw-r--r-- reference_model/src/ops/tensor_ops.h 24
-rw-r--r-- reference_model/src/ops/type_conversion.cc 116
-rw-r--r-- reference_model/src/ops/type_conversion.h 99
28 files changed, 998 insertions, 649 deletions
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
index 24bd077..6681d6d 100644
--- a/reference_model/src/ops/activation_funcs.cc
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpClamp<Rank, Dtype>::register_fcn()
{
// Check Tosa Level
@@ -32,9 +32,9 @@ int OpClamp<Rank, Dtype>::register_fcn()
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
{
InEigenType min = (InEigenType)attribute->min_fp();
InEigenType max = (InEigenType)attribute->max_fp();
@@ -43,8 +43,17 @@ int OpClamp<Rank, Dtype>::register_fcn()
this->fcn = [min, max](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(a <= min ? min : a >= max ? max : a); };
}
break;
- case DType_INT8:
- case DType_INT16:
+ case TOSA_REF_TYPE_FP64:
+ {
+ InEigenType min = (InEigenType)attribute->min_fp();
+ InEigenType max = (InEigenType)attribute->max_fp();
+ ERROR_IF(max < min, "OpClamp: max smaller than min");
+
+ this->fcn = [min, max](InEigenType a) -> OutEigenType { return (a <= min ? min : a >= max ? max : a); };
+ }
+ break;
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
{
InEigenType min = (InEigenType)attribute->min_int();
InEigenType max = (InEigenType)attribute->max_int();
@@ -53,19 +62,19 @@ int OpClamp<Rank, Dtype>::register_fcn()
}
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpClamp<Rank, Dtype>::~OpClamp()
{
if (attribute) delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpSigmoid<Rank, Dtype>::register_fcn()
{
// Check Tosa Level
@@ -74,21 +83,24 @@ int OpSigmoid<Rank, Dtype>::register_fcn()
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a) -> OutEigenType {
return fpTrunc<Dtype>(1.f / (1.f + (expf(-1.f * a))));
};
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a) -> OutEigenType { return (1.L / (1.L + (exp(-1.L * a)))); };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpTanh<Rank, Dtype>::register_fcn()
{
// Check Tosa Level
@@ -97,13 +109,16 @@ int OpTanh<Rank, Dtype>::register_fcn()
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(tanhf(a)); };
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a) -> OutEigenType { return tanh(a); };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
@@ -115,11 +130,14 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP64);
diff --git a/reference_model/src/ops/activation_funcs.h b/reference_model/src/ops/activation_funcs.h
index 9a697cd..2372fcb 100644
--- a/reference_model/src/ops/activation_funcs.h
+++ b/reference_model/src/ops/activation_funcs.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ using namespace tosa;
namespace TosaReference
{
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpClamp : public UnaryNode<Rank, Dtype>
{
public:
@@ -45,7 +45,7 @@ protected:
TosaClampAttribute* attribute;
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpSigmoid : public UnaryNode<Rank, Dtype>
{
public:
@@ -61,7 +61,7 @@ public:
virtual int register_fcn();
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpTanh : public UnaryNode<Rank, Dtype>
{
public:
diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc
index a5711eb..8a084c7 100644
--- a/reference_model/src/ops/comparison.cc
+++ b/reference_model/src/ops/comparison.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpEqual<Rank, Dtype>::register_fcn()
{
// Check Tosa Level
@@ -31,20 +31,21 @@ int OpEqual<Rank, Dtype>::register_fcn()
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
- case DType_INT32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_INT32:
+ case TOSA_REF_TYPE_FP64:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpGreater<Rank, Dtype>::register_fcn()
{
// Check Tosa Level
@@ -53,20 +54,21 @@ int OpGreater<Rank, Dtype>::register_fcn()
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
- case DType_INT32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_INT32:
+ case TOSA_REF_TYPE_FP64:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpGreaterEqual<Rank, Dtype>::register_fcn()
{
// Check Tosa Level
@@ -75,14 +77,15 @@ int OpGreaterEqual<Rank, Dtype>::register_fcn()
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
- case DType_INT32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_INT32:
+ case TOSA_REF_TYPE_FP64:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
@@ -93,13 +96,16 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP64);
diff --git a/reference_model/src/ops/comparison.h b/reference_model/src/ops/comparison.h
index 29e6b5a..263df6a 100644
--- a/reference_model/src/ops/comparison.h
+++ b/reference_model/src/ops/comparison.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -24,45 +24,45 @@ using namespace tosa;
namespace TosaReference
{
-template <int Rank, DType Dtype>
-class OpEqual : public BinaryNode<Rank, Dtype, DType_BOOL>
+template <int Rank, TOSA_REF_TYPE Dtype>
+class OpEqual : public BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>
{
public:
OpEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
- : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_EQUAL, id_)
+ : BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>(sgt_, Op_EQUAL, id_)
{
register_fcn();
}
using InEigenType = typename GetEigenType<Dtype>::type;
- using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type;
virtual int register_fcn();
};
-template <int Rank, DType Dtype>
-class OpGreater : public BinaryNode<Rank, Dtype, DType_BOOL>
+template <int Rank, TOSA_REF_TYPE Dtype>
+class OpGreater : public BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>
{
public:
OpGreater(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
- : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_GREATER, id_)
+ : BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>(sgt_, Op_GREATER, id_)
{
register_fcn();
}
using InEigenType = typename GetEigenType<Dtype>::type;
- using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type;
virtual int register_fcn();
};
-template <int Rank, DType Dtype>
-class OpGreaterEqual : public BinaryNode<Rank, Dtype, DType_BOOL>
+template <int Rank, TOSA_REF_TYPE Dtype>
+class OpGreaterEqual : public BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>
{
public:
OpGreaterEqual(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
- : BinaryNode<Rank, Dtype, DType_BOOL>(sgt_, Op_EQUAL, id_)
+ : BinaryNode<Rank, Dtype, TOSA_REF_TYPE_BOOL>(sgt_, Op_EQUAL, id_)
{
register_fcn();
}
using InEigenType = typename GetEigenType<Dtype>::type;
- using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type;
virtual int register_fcn();
};
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
index f573d5b..03ad6c6 100644
--- a/reference_model/src/ops/control_flow.cc
+++ b/reference_model/src/ops/control_flow.cc
@@ -174,8 +174,8 @@ int OpCondIf::checkTensorAttributes()
{
ERROR_IF(getInputs().size() < 1, "OpCondIf: must have at least 1 operand");
- ERROR_IF(inputs[0]->getDtype() != DType_BOOL || inputs[0]->getRank() != 0,
- "OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNamesDType()[inputs[0]->getDtype()],
+ ERROR_IF(inputs[0]->getDtype() != TOSA_REF_TYPE_BOOL || inputs[0]->getRank() != 0,
+ "OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNameTOSAREFTYPE(inputs[0]->getDtype()),
inputs[0]->getRank());
cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]);
@@ -223,9 +223,9 @@ int OpCondIf::checkTensorAttributes()
std::string else_block_input_name = else_block->GetInputs()[i];
TosaSerializationTensor* then_block_input = then_block->GetTensorByName(then_block_input_name);
TosaSerializationTensor* else_block_input = else_block->GetTensorByName(else_block_input_name);
- ERROR_IF(operator_input->getDtype() != then_block_input->GetDtype(),
+ ERROR_IF(operator_input->getDtype() != ConvertDType(then_block_input->GetDtype()),
"OpCondIf: input tensor type mismatch with then_block input type");
- ERROR_IF(operator_input->getDtype() != else_block_input->GetDtype(),
+ ERROR_IF(operator_input->getDtype() != ConvertDType(else_block_input->GetDtype()),
"OpCondIf: input tensor type mismatch with else_block input type");
ERROR_IF(operator_input->getRank() != (int32_t)then_block_input->GetShape().size(),
"OpCondIf: input tensor rank mismatch with then_block input rank");
@@ -247,9 +247,9 @@ int OpCondIf::checkTensorAttributes()
std::string else_block_output_name = else_block->GetOutputs()[i];
TosaSerializationTensor* then_block_output = then_block->GetTensorByName(then_block_output_name);
TosaSerializationTensor* else_block_output = else_block->GetTensorByName(else_block_output_name);
- ERROR_IF(operator_output->getDtype() != then_block_output->GetDtype(),
+ ERROR_IF(operator_output->getDtype() != ConvertDType(then_block_output->GetDtype()),
"OpCondIf: output tensor type mismatch with then_block output type");
- ERROR_IF(operator_output->getDtype() != else_block_output->GetDtype(),
+ ERROR_IF(operator_output->getDtype() != ConvertDType(else_block_output->GetDtype()),
"OpCondIf: output tensor type mismatch with else_block output type");
ERROR_IF(operator_output->getRank() != (int32_t)then_block_output->GetShape().size(),
"OpCondIf: output tensor rank mismatch with then_block output rank");
@@ -364,11 +364,11 @@ int OpWhileLoop::checkTensorAttributes()
TosaSerializationTensor* body_block_input = body_block->GetTensorByName(body_block_input_name);
TosaSerializationTensor* body_block_output = body_block->GetTensorByName(body_block_output_name);
- ERROR_IF(operator_input->getDtype() != cond_block_input->GetDtype(),
+ ERROR_IF(operator_input->getDtype() != ConvertDType(cond_block_input->GetDtype()),
"OpWhileLoop: input tensor type mismatch with cond_block input type");
- ERROR_IF(operator_input->getDtype() != body_block_input->GetDtype(),
+ ERROR_IF(operator_input->getDtype() != ConvertDType(body_block_input->GetDtype()),
"OpWhileLoop: input tensor type mismatch with body_block input type");
- ERROR_IF(operator_input->getDtype() != body_block_output->GetDtype(),
+ ERROR_IF(operator_input->getDtype() != ConvertDType(body_block_output->GetDtype()),
"OpWhileLoop: input tensor type mismatch with body_block output type");
ERROR_IF(operator_input->getRank() != (int32_t)cond_block_input->GetShape().size(),
"OpWhileLoop: input tensor rank mismatch with cond_block input rank");
@@ -399,8 +399,7 @@ int OpWhileLoop::checkTensorAttributes()
int OpWhileLoop::eval()
{
-
- TosaReference::Tensor0<bool> cond_output_ctensor(std::string("cond_output"), DType_BOOL, std::vector<int32_t>({}));
+ TosaReference::Tensor0<bool> cond_output_ctensor("cond_output", DType_BOOL, std::vector<int32_t>({}));
cond_output_ctensor.allocate();
std::vector<TosaReference::Tensor*> cond_block_outputs;
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index a189466..442cef8 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -20,7 +20,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -32,14 +32,14 @@ OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Axis);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpConcat<Rank, Dtype>::~OpConcat()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpConcat<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -100,7 +100,7 @@ int OpConcat<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpConcat<Rank, Dtype>::eval()
{
@@ -124,7 +124,7 @@ int OpConcat<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -136,12 +136,12 @@ OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Pad);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpPad<Rank, Dtype>::~OpPad()
{
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpPad<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -185,22 +185,23 @@ int OpPad<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpPad<Rank, Dtype>::eval()
{
InEigenType pad_value = 0;
switch (Dtype)
{
- case DType_BOOL:
- case DType_INT8:
- case DType_INT16:
- case DType_INT32:
+ case TOSA_REF_TYPE_BOOL:
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_INT32:
pad_value = (InEigenType)attribute->pad_const_int();
break;
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_FP64:
pad_value = (InEigenType)attribute->pad_const_fp();
break;
default:
@@ -213,7 +214,7 @@ int OpPad<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int InRank, int OutRank, DType Dtype>
+template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -225,14 +226,14 @@ OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Reshape);
}
-template <int InRank, int OutRank, DType Dtype>
+template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
OpReshape<InRank, OutRank, Dtype>::~OpReshape()
{
if (attribute)
delete attribute;
}
-template <int InRank, int OutRank, DType Dtype>
+template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -270,7 +271,7 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int InRank, int OutRank, DType Dtype>
+template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
int OpReshape<InRank, OutRank, Dtype>::eval()
{
for (int32_t d = 0; d < OutRank; d++)
@@ -313,7 +314,7 @@ int OpReshape<InRank, OutRank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -325,14 +326,14 @@ OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Axis);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpReverse<Rank, Dtype>::~OpReverse()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReverse<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -376,7 +377,7 @@ int OpReverse<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReverse<Rank, Dtype>::eval()
{
out->getTensor() = in->getTensor().reverse(reverse_array);
@@ -384,7 +385,7 @@ int OpReverse<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -396,14 +397,14 @@ OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Slice);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpSlice<Rank, Dtype>::~OpSlice()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpSlice<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -449,7 +450,7 @@ int OpSlice<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpSlice<Rank, Dtype>::eval()
{
out->getTensor() = in->getTensor().slice(begin_array, size_array);
@@ -457,7 +458,7 @@ int OpSlice<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -469,14 +470,14 @@ OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Tile);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpTileBase<Rank, Dtype>::~OpTileBase()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpTileBase<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -518,14 +519,14 @@ int OpTileBase<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpTile<Rank, Dtype>::eval()
{
// primary template shouldn't be called
- FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]);
+ FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNameTOSAREFTYPE(Dtype));
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpTile<1, Dtype>::eval()
{
for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -537,7 +538,7 @@ int OpTile<1, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpTile<2, Dtype>::eval()
{
for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -553,7 +554,7 @@ int OpTile<2, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpTile<3, Dtype>::eval()
{
for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -573,7 +574,7 @@ int OpTile<3, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpTile<4, Dtype>::eval()
{
for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -597,7 +598,7 @@ int OpTile<4, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpTile<5, Dtype>::eval()
{
for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -626,7 +627,7 @@ int OpTile<5, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpTile<6, Dtype>::eval()
{
for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -659,7 +660,7 @@ int OpTile<6, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -671,13 +672,13 @@ OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Transpose);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpTranspose<Rank, Dtype>::~OpTranspose()
{
if (attribute) delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpTranspose<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -727,7 +728,7 @@ int OpTranspose<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpTranspose<Rank, Dtype>::eval()
{
out->getTensor() = in->getTensor().shuffle(perm_array);
@@ -743,6 +744,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16);
@@ -751,6 +753,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64);
DEF_INSTANTIATE_RESHAPE(OpReshape, FP16);
DEF_INSTANTIATE_RESHAPE(OpReshape, BF16);
@@ -759,6 +762,7 @@ DEF_INSTANTIATE_RESHAPE(OpReshape, INT8);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT32);
DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);
+DEF_INSTANTIATE_RESHAPE(OpReshape, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16);
@@ -767,6 +771,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
@@ -775,6 +780,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16);
@@ -783,6 +789,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
@@ -791,6 +798,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
@@ -799,3 +807,4 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64);
diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h
index 3a6cb0d..94ce248 100644
--- a/reference_model/src/ops/data_layout.h
+++ b/reference_model/src/ops/data_layout.h
@@ -23,7 +23,7 @@ using namespace tosa;
namespace TosaReference
{
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpConcat : public GraphNode
{
public:
@@ -45,7 +45,7 @@ protected:
TosaReference::TensorTemplate<TOut>* out;
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpPad : public GraphNode
{
public:
@@ -66,7 +66,7 @@ protected:
TosaPadAttribute* attribute;
};
-template <int InRank, int OutRank, DType Dtype>
+template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
class OpReshape : public GraphNode
{
public:
@@ -90,7 +90,7 @@ protected:
TosaReference::TensorTemplate<TOut>* out;
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpReverse : public GraphNode
{
public:
@@ -112,7 +112,7 @@ protected:
Eigen::array<bool, Rank> reverse_array;
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpSlice : public GraphNode
{
public:
@@ -135,7 +135,7 @@ protected:
TosaReference::TensorTemplate<TOut>* out;
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpTileBase : public GraphNode
{
public:
@@ -156,7 +156,7 @@ protected:
};
// primary template for op tile
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpTile : public OpTileBase<Rank, Dtype>
{
public:
@@ -170,12 +170,12 @@ protected:
// partial specialization for specific rank
#define DEF_OP_TILE_RANK(N) \
- template <DType Dtype> \
+ template <TOSA_REF_TYPE Dtype> \
class OpTile<N, Dtype> : public OpTileBase<N, Dtype> \
{ \
public: \
- OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
- : OpTileBase<N, Dtype>(sgt_, attribute_, id_) \
+ OpTile(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
+ : OpTileBase<N, Dtype>(sgt_, attribute_, id_) \
{} \
\
protected: \
@@ -191,7 +191,7 @@ DEF_OP_TILE_RANK(6)
#undef DEF_OP_TILE_RANK
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpTranspose : public GraphNode
{
public:
diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc
index f5304a5..b7f987a 100644
--- a/reference_model/src/ops/data_nodes.cc
+++ b/reference_model/src/ops/data_nodes.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -42,7 +42,7 @@ int OpConst::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpIdentity<Rank, Dtype>::OpIdentity(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -52,11 +52,11 @@ OpIdentity<Rank, Dtype>::OpIdentity(SubgraphTraverser* sgt_,
setRequiredRank(0, 6);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpIdentity<Rank, Dtype>::~OpIdentity()
{}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpIdentity<Rank, Dtype>::checkTensorAttributes()
{
@@ -78,7 +78,7 @@ int OpIdentity<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpIdentity<Rank, Dtype>::eval()
{
out->getTensor() = in->getTensor();
@@ -96,3 +96,4 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64);
diff --git a/reference_model/src/ops/data_nodes.h b/reference_model/src/ops/data_nodes.h
index 8761a08..395c667 100644
--- a/reference_model/src/ops/data_nodes.h
+++ b/reference_model/src/ops/data_nodes.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ public:
virtual int eval();
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpIdentity : public GraphNode
{
public:
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 6aa0c0f..c5801e7 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -22,7 +22,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
const Op& op_,
uint64_t id_)
@@ -37,11 +37,11 @@ BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(SubgraphTraverser* sgt_,
fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase()
{}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -90,7 +90,7 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
{
const std::vector<int>& a_shape = a->getShape();
@@ -106,7 +106,7 @@ int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
return 0;
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int BinaryNode<Rank, InDtype, OutDtype>::eval()
{
this->broadcast();
@@ -124,7 +124,7 @@ int BinaryNode<Rank, InDtype, OutDtype>::eval()
}
// still need to partial specialize this, or Eigen will throw static assertion
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int BinaryNode<0, InDtype, OutDtype>::eval()
{
this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);
@@ -132,12 +132,12 @@ int BinaryNode<0, InDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpAdd<Rank, Dtype>::register_fcn()
{
switch (InDtype)
{
- case DType_INT32:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
int64_t res_in_64 = static_cast<int64_t>(a) + b;
int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
@@ -146,36 +146,39 @@ int OpAdd<Rank, Dtype>::register_fcn()
return static_cast<InEigenType>(res_in_64);
};
break;
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a + b); };
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
{
bool round = attribute->round();
int32_t num_bits = 0;
switch (Dtype)
{
- case DType_INT8:
+ case TOSA_REF_TYPE_INT8:
num_bits = 8;
break;
- case DType_INT16:
+ case TOSA_REF_TYPE_INT16:
num_bits = 16;
break;
- case DType_INT32:
+ case TOSA_REF_TYPE_INT32:
num_bits = 32;
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
this->fcn = [this, round, num_bits](InEigenType a, InEigenType b) -> OutEigenType {
@@ -195,69 +198,69 @@ int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpArithmeticRightShift<Rank, Dtype>::~OpArithmeticRightShift()
{
if (attribute) delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpBitwiseAnd<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_INT8:
- case DType_INT16:
- case DType_INT32:
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpBitwiseOr<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_INT8:
- case DType_INT16:
- case DType_INT32:
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpBitwiseXor<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_INT8:
- case DType_INT16:
- case DType_INT32:
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpIntdiv<Rank, Dtype>::register_fcn()
{
switch (InDtype)
{
- case DType_INT32:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
REQUIRE(b != 0, "OpIntDiv: divisor must be non-zero value");
int64_t res_in_64 = static_cast<int64_t>(a) / b;
@@ -268,47 +271,47 @@ int OpIntdiv<Rank, Dtype>::register_fcn()
};
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpLogicalAnd<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_BOOL:
+ case TOSA_REF_TYPE_BOOL:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_INT8:
+ case TOSA_REF_TYPE_INT8:
this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
(int32_t)b);
return static_cast<OutEigenType>(static_cast<int8_t>(a << b));
};
break;
- case DType_INT16:
+ case TOSA_REF_TYPE_INT16:
this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
(int32_t)b);
return static_cast<OutEigenType>(static_cast<int16_t>(a << b));
};
break;
- case DType_INT32:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
REQUIRE(b >= 0 && b <= 31, "OpLogicalLeftShift: shift value %d is out of valid range [0, 31]",
(int32_t)b);
@@ -316,32 +319,32 @@ int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
};
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpLogicalRightShift<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_INT8:
+ case TOSA_REF_TYPE_INT8:
this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
(int32_t)b);
return static_cast<OutEigenType>(static_cast<int8_t>(a) >> b);
};
break;
- case DType_INT16:
+ case TOSA_REF_TYPE_INT16:
this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
(int32_t)b);
return static_cast<OutEigenType>(static_cast<int16_t>(a) >> b);
};
break;
- case DType_INT32:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
REQUIRE(b >= 0 && b <= 31, "OpLogicalRightShift: shift value %d is out of valid range [0, 31]",
(int32_t)b);
@@ -349,91 +352,96 @@ int OpLogicalRightShift<Rank, Dtype>::register_fcn()
};
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpLogicalOr<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_BOOL:
+ case TOSA_REF_TYPE_BOOL:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpLogicalXor<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_BOOL:
+ case TOSA_REF_TYPE_BOOL:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpMaximum<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
- case DType_INT32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_FP64:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpMinimum<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
- case DType_INT32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_FP64:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int OpMul<Rank, InDtype, OutDtype>::register_fcn()
{
int32_t shift = attribute->shift();
switch (InDtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a * b); };
break;
- case DType_INT32:
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
+ break;
+ case TOSA_REF_TYPE_INT32:
this->fcn = [this, shift](InEigenType a, InEigenType b) -> OutEigenType {
int64_t result;
if (shift > 0)
@@ -457,8 +465,8 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn()
return static_cast<OutEigenType>(result);
};
break;
- case DType_INT8:
- case DType_INT16:
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType {
OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs;
@@ -468,41 +476,44 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn()
};
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype));
}
return 0;
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
OpMul<Rank, InDtype, OutDtype>::~OpMul()
{
if (attribute) delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpPow<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(powf(a, b)); };
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return pow(a, b); };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpSub<Rank, Dtype>::register_fcn()
{
switch (InDtype)
{
- case DType_INT32:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [this](InEigenType a, InEigenType b) -> OutEigenType {
int64_t res_in_64 = static_cast<int64_t>(a) - b;
int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<InEigenType>::max());
@@ -511,19 +522,22 @@ int OpSub<Rank, Dtype>::register_fcn()
return static_cast<InEigenType>(res_in_64);
};
break;
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return fpTrunc<OutDtype>(a - b); };
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype));
}
return 0;
}
-template <int Rank, DType InDtype>
+template <int Rank, TOSA_REF_TYPE InDtype>
OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -535,13 +549,13 @@ OpTable<Rank, InDtype>::OpTable(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Table);
}
-template <int Rank, DType InDtype>
+template <int Rank, TOSA_REF_TYPE InDtype>
OpTable<Rank, InDtype>::~OpTable()
{
if (attribute) delete attribute;
}
-template <int Rank, DType InDtype>
+template <int Rank, TOSA_REF_TYPE InDtype>
int OpTable<Rank, InDtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -573,12 +587,12 @@ int OpTable<Rank, InDtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType InDtype>
+template <int Rank, TOSA_REF_TYPE InDtype>
int OpTable<Rank, InDtype>::eval()
{
switch (InDtype)
{
- case DType_INT8:
+ case TOSA_REF_TYPE_INT8:
this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
int32_t index = input_truncated - QInMin;
@@ -587,7 +601,7 @@ int OpTable<Rank, InDtype>::eval()
return value;
});
break;
- case DType_INT16:
+ case TOSA_REF_TYPE_INT16:
this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
// 1. make sure input is int16 range
int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
@@ -610,7 +624,7 @@ int OpTable<Rank, InDtype>::eval()
});
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[InDtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(InDtype));
}
return GraphNode::eval();
@@ -630,11 +644,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, BF16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP32, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, INT32, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNodeBase, FP64, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
@@ -672,11 +688,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, BF16, BF16);
@@ -684,15 +702,18 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP64, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
@@ -703,3 +724,4 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, BF16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP32, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP64, BOOL);
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index 020ddb5..5f6e531 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -38,7 +38,7 @@ namespace TosaReference
// the way of registering lambda + .binaryExpr() might sacrifice performance here
// but it can avoid partially specialization for combination of {rankN, rank0} x {FP32/INT32, QU8, ...}
// needs to revisit if performance becomes a bottleneck here
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
class BinaryNodeBase : public GraphNode
{
public:
@@ -67,7 +67,7 @@ protected:
};
// primary class
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
class BinaryNode : public BinaryNodeBase<Rank, InDtype, OutDtype>
{
public:
@@ -86,7 +86,7 @@ public:
};
// partial specialization for rank 0
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype>
{
public:
@@ -100,19 +100,19 @@ public:
};
#define DEF_TEMPLATE_BINARY_OP_DEFAULT(Opname, OPNAME) \
- template <int Rank, DType Dtype> \
+ template <int Rank, TOSA_REF_TYPE Dtype> \
class Op##Opname : public BinaryNode<Rank, Dtype, Dtype> \
{ \
public: \
- Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
- : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_##OPNAME, id_) \
+ Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
+ : BinaryNode<Rank, Dtype, Dtype>(sgt_, Op_##OPNAME, id_) \
{ \
register_fcn(); \
} \
- static constexpr DType InDtype = Dtype; \
- static constexpr DType OutDtype = Dtype; \
- using InEigenType = typename GetEigenType<InDtype>::type; \
- using OutEigenType = typename GetEigenType<OutDtype>::type; \
+ static constexpr TOSA_REF_TYPE InDtype = Dtype; \
+ static constexpr TOSA_REF_TYPE OutDtype = Dtype; \
+ using InEigenType = typename GetEigenType<InDtype>::type; \
+ using OutEigenType = typename GetEigenType<OutDtype>::type; \
virtual int register_fcn(); \
};
@@ -133,7 +133,7 @@ DEF_TEMPLATE_BINARY_OP_DEFAULT(Sub, SUB)
#undef DEF_TEMPLATE_BINARY_OP_DEFAULT
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpArithmeticRightShift : public BinaryNode<Rank, Dtype, Dtype>
{
public:
@@ -154,7 +154,7 @@ protected:
TosaArithmeticRightShiftAttribute* attribute;
};
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
class OpMul : public BinaryNode<Rank, InDtype, OutDtype>
{
public:
@@ -175,7 +175,7 @@ protected:
TosaMulAttribute* attribute;
};
-template <int Rank, DType InDtype>
+template <int Rank, TOSA_REF_TYPE InDtype>
class OpTable : public GraphNode
{
public:
@@ -185,9 +185,11 @@ public:
virtual int checkTensorAttributes();
virtual int eval();
- static constexpr DType TableDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT16;
- static constexpr DType OutDtype = (InDtype == DType_INT8) ? DType_INT8 : DType_INT32;
- static constexpr uint32_t TableNumEntries = (InDtype == DType_INT8) ? 256 : 513;
+ static constexpr TOSA_REF_TYPE TableDtype =
+ (InDtype == TOSA_REF_TYPE_INT8) ? TOSA_REF_TYPE_INT8 : TOSA_REF_TYPE_INT16;
+ static constexpr TOSA_REF_TYPE OutDtype =
+ (InDtype == TOSA_REF_TYPE_INT8) ? TOSA_REF_TYPE_INT8 : TOSA_REF_TYPE_INT32;
+ static constexpr uint32_t TableNumEntries = (InDtype == TOSA_REF_TYPE_INT8) ? 256 : 513;
using InEigenType = typename GetEigenType<InDtype>::type;
using TableEigenType = typename GetEigenType<TableDtype>::type;
using OutEigenType = typename GetEigenType<OutDtype>::type;
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc
index 4d53ae4..090ce29 100644
--- a/reference_model/src/ops/ewise_ternary.cc
+++ b/reference_model/src/ops/ewise_ternary.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -19,7 +19,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpSelectBase<Rank, Dtype>::OpSelectBase(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -29,11 +29,11 @@ OpSelectBase<Rank, Dtype>::OpSelectBase(SubgraphTraverser* sgt_,
setRequiredRank(0, 6);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpSelectBase<Rank, Dtype>::~OpSelectBase()
{}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpSelectBase<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -66,13 +66,13 @@ int OpSelectBase<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpSelectBase<Rank, Dtype>::eval()
{
FATAL_ERROR("shouldn't be called");
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpSelect<Rank, Dtype>::broadcast()
{
const std::vector<int>& cond_shape = this->cond->getShape();
@@ -90,7 +90,7 @@ int OpSelect<Rank, Dtype>::broadcast()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpSelect<Rank, Dtype>::eval()
{
this->broadcast();
@@ -102,7 +102,7 @@ int OpSelect<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpSelect<0, Dtype>::eval()
{
this->out->getTensor() = this->cond->getTensor().select(this->then_val->getTensor(), this->else_val->getTensor());
@@ -118,6 +118,7 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelectBase, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BF16);
@@ -126,3 +127,4 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP64);
diff --git a/reference_model/src/ops/ewise_ternary.h b/reference_model/src/ops/ewise_ternary.h
index 75a2194..c6970cb 100644
--- a/reference_model/src/ops/ewise_ternary.h
+++ b/reference_model/src/ops/ewise_ternary.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -29,7 +29,7 @@ namespace TosaReference
// 3. Else_val: Rank N, type=<V>
// 4. Result: Rank N, type=<V>
// Cond, Then_val, Else_val need to be mutually-broadcastable
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpSelectBase : public GraphNode
{
public:
@@ -39,7 +39,7 @@ public:
virtual int checkTensorAttributes();
virtual int eval();
- using CondEigenType = typename GetEigenType<DType_BOOL>::type;
+ using CondEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type;
using InEigenType = typename GetEigenType<Dtype>::type;
using TCond = Eigen::Tensor<CondEigenType, Rank>;
using TIn = Eigen::Tensor<InEigenType, Rank>;
@@ -55,7 +55,7 @@ protected:
};
// primary class
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpSelect : public OpSelectBase<Rank, Dtype>
{
public:
@@ -69,7 +69,7 @@ public:
};
// partial specialization for rank 0
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
class OpSelect<0, Dtype> : public OpSelectBase<0, Dtype>
{
public:
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
index 8dc37e2..514cb84 100644
--- a/reference_model/src/ops/ewise_unary.cc
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
UnaryNode<Rank, Dtype>::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64_t id_)
: GraphNode(sgt_, op_, id_)
{
@@ -35,11 +35,11 @@ UnaryNode<Rank, Dtype>::UnaryNode(SubgraphTraverser* sgt_, const Op& op_, uint64
};
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
UnaryNode<Rank, Dtype>::~UnaryNode()
{}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int UnaryNode<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -69,7 +69,7 @@ int UnaryNode<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int UnaryNode<Rank, Dtype>::eval()
{
this->result->getTensor() = this->a->getTensor().unaryExpr(this->fcn);
@@ -77,71 +77,75 @@ int UnaryNode<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpAbs<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP32: // No fpTrunc for FP32 as it is a no-op
- case DType_INT32:
+ case TOSA_REF_TYPE_FP32: // No fpTrunc for FP32 as it is a no-op
+ case TOSA_REF_TYPE_FP64:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); };
break;
- case DType_FP16:
- case DType_BF16:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(a > (InEigenType)0 ? a : (-a)); };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpBitwiseNot<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_INT8:
- case DType_INT16:
- case DType_INT32:
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [](InEigenType a) -> OutEigenType { return ~a; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpCeil<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(ceilf(a)); };
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a) -> OutEigenType { return ceil(a); };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpClz<Rank, Dtype>::register_fcn()
{
int32_t num_bits;
switch (Dtype)
{
- case DType_INT32:
+ case TOSA_REF_TYPE_INT32:
num_bits = 32;
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
this->fcn = [num_bits](int32_t a) -> int32_t {
@@ -163,73 +167,82 @@ int OpClz<Rank, Dtype>::register_fcn()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpExp<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(expf(a)); };
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a) -> OutEigenType { return exp(a); };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpFloor<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(floorf(a)); };
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a) -> OutEigenType { return floor(a); };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpLog<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(logf(a)); };
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a) -> OutEigenType { return log(a); };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpLogicalNot<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_BOOL:
+ case TOSA_REF_TYPE_BOOL:
this->fcn = [](InEigenType a) -> OutEigenType { return !a; };
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpNegate<Rank, Dtype>::OpNegate(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -240,31 +253,37 @@ OpNegate<Rank, Dtype>::OpNegate(SubgraphTraverser* sgt_,
register_fcn();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpNegate<Rank, Dtype>::~OpNegate()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpNegate<Rank, Dtype>::register_fcn()
{
- ERROR_IF(Dtype != DType_INT8 && attribute->input1_zp() != 0, "OpNegate: zeropoint only for int8_t");
- ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpNegate: zeropoint only for int8_t");
+ ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->input1_zp() != 0, "OpNegate: zeropoint only for int8_t");
+ ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->output_zp() != 0, "OpNegate: zeropoint only for int8_t");
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a) -> OutEigenType {
InEigenType result = -(a);
return fpTrunc<Dtype>(result);
};
break;
- case DType_INT16:
- case DType_INT32:
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a) -> OutEigenType {
+ OutEigenType result = -(a);
+ return result;
+ };
+ break;
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_INT32:
this->fcn = [this](InEigenType a) -> OutEigenType {
int64_t res_in_64 = 0L - a;
int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
@@ -272,7 +291,7 @@ int OpNegate<Rank, Dtype>::register_fcn()
REQUIRE(res_in_64 <= i32_max_in_64 && res_in_64 >= i32_min_in_64, "OpNegate: result not in acc type range (int32)");
int64_t max_clip_in_64, min_clip_in_64;
- if (Dtype == DType_INT16)
+ if (Dtype == TOSA_REF_TYPE_INT16)
{
max_clip_in_64 = static_cast<int64_t>(std::numeric_limits<int16_t>::max());
min_clip_in_64 = static_cast<int64_t>(std::numeric_limits<int16_t>::min());
@@ -285,7 +304,7 @@ int OpNegate<Rank, Dtype>::register_fcn()
return static_cast<InEigenType>(std::min<int64_t>(max_clip_in_64, std::max<int64_t>(min_clip_in_64, res_in_64)));
};
break;
- case DType_INT8:
+ case TOSA_REF_TYPE_INT8:
this->fcn = [this](InEigenType a) -> OutEigenType {
int64_t res_in_64 = 0 - (a - attribute->input1_zp());
int64_t i32_max_in_64 = static_cast<int64_t>(std::numeric_limits<int32_t>::max());
@@ -297,41 +316,47 @@ int OpNegate<Rank, Dtype>::register_fcn()
};
break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReciprocal<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(1.0 / a); };
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a) -> OutEigenType { return (1.0L / a); };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpRsqrt<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return fpTrunc<Dtype>(1.0 / sqrtf(a)); };
break;
+ case TOSA_REF_TYPE_FP64:
+ this->fcn = [](InEigenType a) -> OutEigenType { return (1.0L / sqrt(a)); };
+ break;
default:
- ERROR_IF(true, "unsupported DType %s", EnumNamesDType()[Dtype]);
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return 0;
@@ -345,11 +370,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(UnaryNode, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16);
@@ -358,20 +385,24 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
@@ -381,11 +412,14 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP64);
diff --git a/reference_model/src/ops/ewise_unary.h b/reference_model/src/ops/ewise_unary.h
index 16a4c88..21ee276 100644
--- a/reference_model/src/ops/ewise_unary.h
+++ b/reference_model/src/ops/ewise_unary.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ using namespace tosa;
namespace TosaReference
{
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class UnaryNode : public GraphNode
{
public:
@@ -45,11 +45,11 @@ protected:
};
#define DEF_TEMPLATE_UNARY_OP(Opname, OPNAME) \
- template <int Rank, DType Dtype> \
+ template <int Rank, TOSA_REF_TYPE Dtype> \
class Op##Opname : public UnaryNode<Rank, Dtype> \
{ \
public: \
- Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
+ Op##Opname(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) \
: UnaryNode<Rank, Dtype>(sgt_, Op_##OPNAME, id_) \
{ \
register_fcn(); \
@@ -75,7 +75,7 @@ DEF_TEMPLATE_UNARY_OP(Rsqrt, RSQRT)
#undef DEF_TEMPLATE_UNARY_OP
// Negate is the only unary op with attributes
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpNegate : public UnaryNode<Rank, Dtype>
{
public:
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index 190b354..ca12cfe 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <DType InDtype, DType OutDtype, typename resize_t>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t>
OpResize<InDtype, OutDtype, resize_t>::OpResize(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -35,14 +35,14 @@ OpResize<InDtype, OutDtype, resize_t>::OpResize(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Resize);
}
-template <DType InDtype, DType OutDtype, typename resize_t>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t>
OpResize<InDtype, OutDtype, resize_t>::~OpResize()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType OutDtype, typename resize_t>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t>
int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -64,7 +64,8 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
if (this->mode == ResizeMode_BILINEAR)
{
- if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16)
+ if (OutDtype != TOSA_REF_TYPE_INT32 && OutDtype != TOSA_REF_TYPE_INT48 && OutDtype != TOSA_REF_TYPE_FP32 &&
+ OutDtype != TOSA_REF_TYPE_FP16 && OutDtype != TOSA_REF_TYPE_BF16 && OutDtype != TOSA_REF_TYPE_FP64)
{
printNodeValidationError("OpResize: invalid data type for BILINEAR");
return 1;
@@ -72,7 +73,8 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
}
else
{
- if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FP32 && OutDtype != DType_FP16 && OutDtype != DType_BF16)
+ if (OutDtype != TOSA_REF_TYPE_INT8 && OutDtype != TOSA_REF_TYPE_INT16 && OutDtype != TOSA_REF_TYPE_FP32 &&
+ OutDtype != TOSA_REF_TYPE_FP16 && OutDtype != TOSA_REF_TYPE_BF16 && OutDtype != TOSA_REF_TYPE_FP64)
{
printNodeValidationError("OpResize: invalid data type for NEAREST");
return 1;
@@ -87,7 +89,7 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType OutDtype, typename resize_t>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t>
int OpResize<InDtype, OutDtype, resize_t>::eval()
{
int in_batch = in->getShape()[0];
@@ -157,24 +159,38 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
int32_t y = oy * scale_y_d + offset_y;
int32_t x = ox * scale_x_d + offset_x;
- float fy = static_cast<float>(y) / static_cast<float>(scale_y_n);
- float fx = static_cast<float>(x) / static_cast<float>(scale_x_n);
-
- int32_t iy = floor(fy);
- int32_t ix = floor(fx);
-
+ int32_t iy;
+ int32_t ix;
resize_t dy;
resize_t dx;
- if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) ||
- (typeid(resize_t) == typeid(half_float::half)))
+ if (std::is_same<resize_t, double>::value)
{
- dy = (resize_t)(fy - iy);
- dx = (resize_t)(fx - ix);
+ const double fy_double = static_cast<double>(y) / static_cast<double>(scale_y_n);
+ const double fx_double = static_cast<double>(x) / static_cast<double>(scale_x_n);
+ iy = floor(fy_double);
+ ix = floor(fx_double);
+
+ dy = (resize_t)(fy_double - iy);
+ dx = (resize_t)(fx_double - ix);
}
else
{
- dy = (resize_t)(y - (iy * scale_y_n));
- dx = (resize_t)(x - (ix * scale_x_n));
+ const float fy = static_cast<float>(y) / static_cast<float>(scale_y_n);
+ const float fx = static_cast<float>(x) / static_cast<float>(scale_x_n);
+ iy = floor(fy);
+ ix = floor(fx);
+
+ if (std::is_floating_point<resize_t>::value || (typeid(resize_t) == typeid(Eigen::bfloat16)) ||
+ (typeid(resize_t) == typeid(half_float::half)))
+ {
+ dy = (resize_t)(fy - iy);
+ dx = (resize_t)(fx - ix);
+ }
+ else
+ {
+ dy = (resize_t)(y - (iy * scale_y_n));
+ dx = (resize_t)(x - (ix * scale_x_n));
+ }
}
int32_t iy0 = MAX(iy, 0);
@@ -248,3 +264,4 @@ DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, INT16, INT16, int16_t);
DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP16, FP16, half_float::half);
DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, BF16, BF16, Eigen::bfloat16);
DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP32, FP32, float);
+DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP64, FP64, double);
diff --git a/reference_model/src/ops/image.h b/reference_model/src/ops/image.h
index 508d2c8..6d5a418 100644
--- a/reference_model/src/ops/image.h
+++ b/reference_model/src/ops/image.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@ using namespace tosa;
namespace TosaReference
{
-template <DType InDtype, DType OutDtype, typename resize_t>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype, typename resize_t>
class OpResize : public GraphNode
{
public:
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 1db3974..0a78884 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -37,11 +37,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
Op opType,
TosaAttributeBase* attribute,
uint64_t id,
- DType inputDType,
+ TOSA_REF_TYPE inputDTYPE,
int inputRank,
- DType outputDType,
+ TOSA_REF_TYPE outputDTYPE,
int outputRank,
- DType weightDType,
+ TOSA_REF_TYPE weightDTYPE,
int weightRank)
{
switch (opType)
@@ -53,6 +53,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64);
break;
case Op_AVG_POOL2D:
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16);
@@ -61,6 +62,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP32, FP32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP64, FP64);
break;
case Op_CONV2D:
DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
@@ -70,6 +72,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
break;
case Op_CONV3D:
DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
@@ -79,6 +82,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
break;
case Op_DEPTHWISE_CONV2D:
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
@@ -88,9 +92,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
break;
case Op_FFT2D:
DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32);
+ DEF_FACTORY_ONE_TYPE(OpFFT2d, FP64);
break;
case Op_FULLY_CONNECTED:
DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
@@ -100,6 +106,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32);
DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP64, FP64, FP64);
break;
case Op_MATMUL:
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP16, FP16);
@@ -108,6 +115,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP32, FP32);
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT8, INT32);
DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT16, INT48);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP64, FP64);
break;
case Op_MAX_POOL2D:
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16);
@@ -115,9 +123,11 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP32);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16);
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP64);
break;
case Op_RFFT2D:
DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP32);
+ DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP64);
break;
case Op_TRANSPOSE_CONV2D:
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
@@ -127,6 +137,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);
break;
// activation_funcs
@@ -136,16 +147,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP64);
break;
case Op_SIGMOID:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP64);
break;
case Op_TANH:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP64);
break;
// ewise_binary
@@ -154,6 +168,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP64);
break;
case Op_ARITHMETIC_RIGHT_SHIFT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
@@ -202,12 +217,14 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP64);
break;
case Op_MINIMUM:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP64);
break;
case Op_MUL:
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
@@ -216,17 +233,20 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP64, FP64);
break;
case Op_POW:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP64);
break;
case Op_SUB:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP64);
break;
case Op_TABLE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
@@ -239,6 +259,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP64);
break;
case Op_BITWISE_NOT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT8);
@@ -249,6 +270,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP64);
break;
case Op_CLZ:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
@@ -257,16 +279,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP64);
break;
case Op_FLOOR:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP64);
break;
case Op_LOG:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP64);
break;
case Op_LOGICAL_NOT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
@@ -278,16 +303,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP64);
break;
case Op_RECIPROCAL:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP64);
break;
case Op_RSQRT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP64);
break;
// ewise_ternary
@@ -299,6 +327,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP64);
break;
// comparison
@@ -307,18 +336,21 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP64);
break;
case Op_GREATER:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP64);
break;
case Op_GREATER_EQUAL:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, BF16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP64);
break;
// reduction
@@ -335,6 +367,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP64);
break;
case Op_REDUCE_MIN:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
@@ -343,16 +376,19 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP64);
break;
case Op_REDUCE_PRODUCT:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProductDouble, FP64);
break;
case Op_REDUCE_SUM:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumDouble, FP64);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);
break;
@@ -365,6 +401,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64);
break;
case Op_PAD:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
@@ -374,6 +411,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64);
break;
case Op_RESHAPE:
DEF_FACTORY_RESHAPE(OpReshape, FP16);
@@ -383,6 +421,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RESHAPE(OpReshape, INT16);
DEF_FACTORY_RESHAPE(OpReshape, INT32);
DEF_FACTORY_RESHAPE(OpReshape, BOOL);
+ DEF_FACTORY_RESHAPE(OpReshape, FP64);
break;
case Op_REVERSE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
@@ -392,6 +431,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64);
break;
case Op_SLICE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
@@ -401,6 +441,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64);
break;
case Op_TILE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
@@ -410,6 +451,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64);
break;
case Op_TRANSPOSE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
@@ -419,6 +461,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64);
break;
// scatter_gather
@@ -429,6 +472,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpGather, FP16);
DEF_FACTORY_ONE_TYPE(OpGather, BF16);
DEF_FACTORY_ONE_TYPE(OpGather, FP32);
+ DEF_FACTORY_ONE_TYPE(OpGather, FP64);
break;
case Op_SCATTER:
DEF_FACTORY_ONE_TYPE(OpScatter, INT8);
@@ -437,6 +481,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpScatter, FP16);
DEF_FACTORY_ONE_TYPE(OpScatter, BF16);
DEF_FACTORY_ONE_TYPE(OpScatter, FP32);
+ DEF_FACTORY_ONE_TYPE(OpScatter, FP64);
break;
// image
@@ -448,6 +493,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OpResize, FP16, FP16);
DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OpResize, BF16, BF16);
DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OpResize, FP32, FP32);
+ DEF_FACTORY_TWO_TYPE_RESIZE_FP64(OpResize, FP64, FP64);
break;
// data_nodes
@@ -461,6 +507,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP64);
break;
// type_conversion
@@ -499,6 +546,13 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64);
break;
case Op_RESCALE:
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
index 9117df4..f276e03 100644
--- a/reference_model/src/ops/op_factory.h
+++ b/reference_model/src/ops/op_factory.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -23,19 +23,19 @@
#define DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) \
case RANK: \
- return new OP<RANK, DType_##DTYPE>(sgt, attribute, id);
+ return new OP<RANK, TOSA_REF_TYPE_##DTYPE>(sgt, attribute, id);
#define DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \
case RANK: \
- return new OP<RANK, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id);
+ return new OP<RANK, TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id);
#define DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \
case RANK2: \
- return new OP<RANK1, RANK2, DType_##DTYPE>(sgt, attribute, id);
+ return new OP<RANK1, RANK2, TOSA_REF_TYPE_##DTYPE>(sgt, attribute, id);
#define DEF_FACTORY_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \
case RANK2: \
- return new OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id);
+ return new OP<RANK1, RANK2, TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id);
#define DEF_FACTORY_ONE_RANK_0_6(OP) \
switch (inputRank) \
@@ -57,40 +57,42 @@
}
#define DEF_FACTORY_ONE_TYPE(OP, DTYPE) \
- if (inputDType == DType_##DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE) \
{ \
- return new OP<DType_##DTYPE>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE>(sgt, attribute, id); \
}
#define DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE, ACCUM_DTYPE) \
- if (inputDType == DType_##DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == TOSA_REF_TYPE_##ACCUM_DTYPE) \
{ \
- return new OP<DType_##DTYPE, DType_##ACCUM_DTYPE>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE, TOSA_REF_TYPE_##ACCUM_DTYPE>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && weightDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE_IN_OUT(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE1, DTYPE2, ACCUM_DTYPE) \
- if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 \
- && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && weightDTYPE == TOSA_REF_TYPE_##DTYPE2 && \
+ ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == TOSA_REF_TYPE_##ACCUM_DTYPE) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, DType_##ACCUM_DTYPE>(sgt, attribute, id); \
- } \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##ACCUM_DTYPE>(sgt, attribute, \
+ id); \
+ }
#define DEF_FACTORY_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3) \
- if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 && outputDType == DType_##DTYPE3) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && weightDTYPE == TOSA_REF_TYPE_##DTYPE2 && \
+ outputDTYPE == TOSA_REF_TYPE_##DTYPE3) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, DType_##DTYPE3>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##DTYPE3>(sgt, attribute, id); \
}
// Statement-expression to evaluate accumulate attribute in-place
@@ -108,35 +110,41 @@
FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute.\nPre-initialization " \
"of this attribute is required in order to determine the accumulate type."); \
} \
- accumDType; \
- }) \
+ ConvertDType(accumDType); \
+ })
#define DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, int16_t>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, int16_t>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, half_float::half>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, half_float::half>(sgt, attribute, id); \
}
#define DEF_FACTORY_TWO_TYPE_RESIZE_BF16(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
+ { \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, Eigen::bfloat16>(sgt, attribute, id); \
+ }
+
+#define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP, DTYPE1, DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, Eigen::bfloat16>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, float>(sgt, attribute, id); \
}
-#define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+#define DEF_FACTORY_TWO_TYPE_RESIZE_FP64(OP, DTYPE1, DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
- return new OP<DType_##DTYPE1, DType_##DTYPE2, float>(sgt, attribute, id); \
+ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, double>(sgt, attribute, id); \
}
#define DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
- if (inputDType == DType_##DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE) \
{ \
switch (inputRank) \
{ \
@@ -151,7 +159,7 @@
}
#define DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
- if (inputDType == DType_##DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE) \
{ \
switch (inputRank) \
{ \
@@ -165,7 +173,7 @@
}
#define DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \
- if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE1 && outputDTYPE == TOSA_REF_TYPE_##DTYPE2) \
{ \
switch (inputRank) \
{ \
@@ -180,7 +188,7 @@
}
#define DEF_FACTORY_RESHAPE(OP, DTYPE) \
- if (inputDType == DType_##DTYPE && outputDType == DType_##DTYPE) \
+ if (inputDTYPE == TOSA_REF_TYPE_##DTYPE && outputDTYPE == TOSA_REF_TYPE_##DTYPE) \
{ \
switch (inputRank) \
{ \
@@ -292,11 +300,11 @@ public:
tosa::Op opType,
tosa::TosaAttributeBase* attribute,
uint64_t id,
- tosa::DType inputDType,
+ TOSA_REF_TYPE inputDTYPE,
int inputRank,
- tosa::DType outputDType,
+ TOSA_REF_TYPE outputDTYPE,
int outputRank,
- tosa::DType weightDType,
+ TOSA_REF_TYPE weightDTYPE,
int weightRank);
};
}; // namespace TosaReference
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
index cd9d55f..bf8ba57 100644
--- a/reference_model/src/ops/reduction.cc
+++ b/reference_model/src/ops/reduction.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -20,7 +20,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
ReduceNode<Rank, Dtype>::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, TosaAttributeBase* attribute_, uint64_t id_)
: GraphNode(sgt_, op_, id_)
{
@@ -30,14 +30,14 @@ ReduceNode<Rank, Dtype>::ReduceNode(SubgraphTraverser* sgt_, const Op& op_, Tosa
INIT_ATTRIBUTE(Axis);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
ReduceNode<Rank, Dtype>::~ReduceNode()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int ReduceNode<Rank, Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -100,7 +100,7 @@ struct AnyReducer {
bool finalize(const bool accum) const { return accum; }
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceAll<Rank, Dtype>::eval()
{
this->out->getTensor() = this->in->getTensor().reduce(this->dims, AllReducer()).reshape(this->out->getTensor().dimensions());
@@ -108,7 +108,7 @@ int OpReduceAll<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceAny<Rank, Dtype>::eval()
{
this->out->getTensor() = this->in->getTensor().reduce(this->dims, AnyReducer()).reshape(this->out->getTensor().dimensions());
@@ -116,7 +116,7 @@ int OpReduceAny<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceMax<Rank, Dtype>::eval()
{
this->out->getTensor() = this->in->getTensor().maximum(this->dims).reshape(this->out->getTensor().dimensions());
@@ -124,7 +124,7 @@ int OpReduceMax<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceMin<Rank, Dtype>::eval()
{
this->out->getTensor() = this->in->getTensor().minimum(this->dims).reshape(this->out->getTensor().dimensions());
@@ -132,35 +132,74 @@ int OpReduceMin<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceProduct<Rank, Dtype>::eval()
{
switch(Dtype)
{
- case DType_FP16:
- case DType_BF16:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);});
break;
- default:
+ case TOSA_REF_TYPE_FP32:
this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions());
break;
+ default:
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
+ }
+
+ return GraphNode::eval();
+}
+
+struct ProductDoubleReducer
+{
+ static const bool PacketAccess = false;
+ void reduce(const double val, double* accum)
+ {
+ *accum *= val;
+ }
+ double initialize() const
+ {
+ return 1.0;
+ }
+ double finalize(const double accum) const
+ {
+ return accum;
+ }
+};
+
+template <int Rank, TOSA_REF_TYPE Dtype>
+int OpReduceProductDouble<Rank, Dtype>::eval()
+{
+ switch (Dtype)
+ {
+ case TOSA_REF_TYPE_FP64:
+ this->out->getTensor() = this->in->getTensor()
+ .reduce(this->dims, ProductDoubleReducer())
+ .reshape(this->out->getTensor().dimensions());
+ break;
+ default:
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceSum<Rank, Dtype>::eval()
{
switch(Dtype)
{
- case DType_FP16:
- case DType_BF16:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()).unaryExpr([](float f){return fpTrunc<Dtype>(f);});
break;
- default:
+ case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_INT32:
this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions());
break;
+ default:
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
}
return GraphNode::eval();
@@ -183,7 +222,7 @@ struct SumRequiresReducer {
SubgraphTraverser* parent_sgt;
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReduceSumInt<Rank, Dtype>::eval()
{
this->out->getTensor() = this->in->getTensor().reduce(this->dims, SumRequiresReducer(this->parent_sgt)).reshape(this->out->getTensor().dimensions());
@@ -191,6 +230,40 @@ int OpReduceSumInt<Rank, Dtype>::eval()
return GraphNode::eval();
}
+struct SumDoubleReducer
+{
+ static const bool PacketAccess = false;
+ void reduce(const double val, double* accum)
+ {
+ *accum += val;
+ }
+ double initialize() const
+ {
+ return 0.0;
+ }
+ double finalize(const double accum) const
+ {
+ return accum;
+ }
+};
+
+template <int Rank, TOSA_REF_TYPE Dtype>
+int OpReduceSumDouble<Rank, Dtype>::eval()
+{
+ switch (Dtype)
+ {
+ case TOSA_REF_TYPE_FP64:
+ this->out->getTensor() = this->in->getTensor()
+ .reduce(this->dims, SumDoubleReducer())
+ .reshape(this->out->getTensor().dimensions());
+ break;
+ default:
+ ERROR_IF(true, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(Dtype));
+ }
+
+ return GraphNode::eval();
+}
+
// template explicit instantiation
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);
@@ -202,6 +275,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, BF16);
@@ -209,12 +283,15 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProductDouble, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumDouble, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);
diff --git a/reference_model/src/ops/reduction.h b/reference_model/src/ops/reduction.h
index 6e98a76..aeb9f1d 100644
--- a/reference_model/src/ops/reduction.h
+++ b/reference_model/src/ops/reduction.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@ using namespace tosa;
namespace TosaReference
{
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class ReduceNode : public GraphNode
{
public:
@@ -44,7 +44,7 @@ protected:
TosaAxisAttribute* attribute;
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpReduceAll : public ReduceNode<Rank, Dtype>
{
public:
@@ -54,7 +54,7 @@ public:
virtual int eval();
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpReduceAny : public ReduceNode<Rank, Dtype>
{
public:
@@ -64,7 +64,7 @@ public:
virtual int eval();
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpReduceMax : public ReduceNode<Rank, Dtype>
{
public:
@@ -74,7 +74,7 @@ public:
virtual int eval();
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpReduceMin : public ReduceNode<Rank, Dtype>
{
public:
@@ -84,7 +84,7 @@ public:
virtual int eval();
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpReduceProduct : public ReduceNode<Rank, Dtype>
{
public:
@@ -94,7 +94,17 @@ public:
virtual int eval();
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
+class OpReduceProductDouble : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceProductDouble(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_PRODUCT, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpReduceSum : public ReduceNode<Rank, Dtype>
{
public:
@@ -104,7 +114,7 @@ public:
virtual int eval();
};
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpReduceSumInt : public ReduceNode<Rank, Dtype>
{
public:
@@ -114,6 +124,16 @@ public:
virtual int eval();
};
+template <int Rank, TOSA_REF_TYPE Dtype>
+class OpReduceSumDouble : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceSumDouble(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(sgt_, Op_REDUCE_SUM, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
}; // namespace TosaReference
#endif
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc
index bcd8ce5..80b6c58 100644
--- a/reference_model/src/ops/scatter_gather.cc
+++ b/reference_model/src/ops/scatter_gather.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -20,7 +20,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
OpGather<Dtype>::OpGather(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -29,11 +29,11 @@ OpGather<Dtype>::OpGather(SubgraphTraverser* sgt_,
setRequiredOperands(2, 1);
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
OpGather<Dtype>::~OpGather()
{}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpGather<Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -96,7 +96,7 @@ int OpGather<Dtype>::checkTensorAttributes()
return 0;
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpGather<Dtype>::eval()
{
for (int32_t n = 0; n < N; n++)
@@ -116,7 +116,7 @@ int OpGather<Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
OpScatter<Dtype>::OpScatter(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -125,11 +125,11 @@ OpScatter<Dtype>::OpScatter(SubgraphTraverser* sgt_,
setRequiredOperands(3, 1);
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
OpScatter<Dtype>::~OpScatter()
{}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpScatter<Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -199,7 +199,7 @@ int OpScatter<Dtype>::checkTensorAttributes()
return 0;
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpScatter<Dtype>::eval()
{
// Initializes the output tensor with the input value for values that are unchanged by the scatter operation.
@@ -229,6 +229,7 @@ DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32);
DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpGather, BF16);
DEF_INSTANTIATE_ONE_TYPE(OpGather, FP32);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, FP64);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16);
@@ -236,3 +237,4 @@ DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, BF16);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP32);
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP64);
diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h
index af09153..fb675a9 100644
--- a/reference_model/src/ops/scatter_gather.h
+++ b/reference_model/src/ops/scatter_gather.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -23,7 +23,7 @@ using namespace tosa;
namespace TosaReference
{
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
class OpGather : public GraphNode
{
public:
@@ -45,7 +45,7 @@ protected:
TosaReference::TensorTemplate<TOutput>* output;
};
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
class OpScatter : public GraphNode
{
public:
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h
index ece14b1..6dd6e76 100644
--- a/reference_model/src/ops/template_types.h
+++ b/reference_model/src/ops/template_types.h
@@ -16,11 +16,10 @@
#ifndef OP_TEMPLATE_TYPES_H
#define OP_TEMPLATE_TYPES_H
-#include "tosa_generated.h"
-#include <Eigen/CXX11/Tensor>
+#include "dtype.h"
#include "half.hpp"
+#include <Eigen/CXX11/Tensor>
#include <Eigen/Core>
-#include "arith_util.h"
using namespace tosa;
@@ -64,213 +63,218 @@ using Tensor5 = TensorTemplate<ETensor5<T>>;
template <typename T>
using Tensor6 = TensorTemplate<ETensor6<T>>;
-template <DType type>
+template <TOSA_REF_TYPE type>
struct GetEigenType;
template <>
-struct GetEigenType<DType_FP32>
+struct GetEigenType<TOSA_REF_TYPE_FP64>
+{
+ using type = double;
+};
+template <>
+struct GetEigenType<TOSA_REF_TYPE_FP32>
{
using type = float;
};
template <>
-struct GetEigenType<DType_FP16>
+struct GetEigenType<TOSA_REF_TYPE_FP16>
{
// NOTE: full precision used
using type = float;
};
template <>
-struct GetEigenType<DType_BF16>
+struct GetEigenType<TOSA_REF_TYPE_BF16>
{
// NOTE: full precision used
using type = float;
};
template <>
-struct GetEigenType<DType_INT32>
+struct GetEigenType<TOSA_REF_TYPE_INT32>
{
using type = int32_t;
};
template <>
-struct GetEigenType<DType_INT48>
+struct GetEigenType<TOSA_REF_TYPE_INT48>
{
using type = int64_t;
};
template <>
-struct GetEigenType<DType_BOOL>
+struct GetEigenType<TOSA_REF_TYPE_BOOL>
{
using type = bool;
};
template <>
-struct GetEigenType<DType_UINT8>
+struct GetEigenType<TOSA_REF_TYPE_UINT8>
{
using type = int32_t;
};
template <>
-struct GetEigenType<DType_UINT16>
+struct GetEigenType<TOSA_REF_TYPE_UINT16>
{
using type = int32_t;
};
template <>
-struct GetEigenType<DType_INT4>
+struct GetEigenType<TOSA_REF_TYPE_INT4>
{
using type = int32_t;
};
template <>
-struct GetEigenType<DType_INT8>
+struct GetEigenType<TOSA_REF_TYPE_INT8>
{
using type = int32_t;
};
template <>
-struct GetEigenType<DType_INT16>
+struct GetEigenType<TOSA_REF_TYPE_INT16>
{
using type = int32_t;
};
/* Get Accumulate Eigen Type:
-Same behaviour as GetEigenType for all DTypes except the
-single specialised case of DType_FP16. */
-template <DType Dtype>
+Same behaviour as GetEigenType for all DTYPEs except the
+single specialised case of TOSA_REF_TYPE_FP16. */
+template <TOSA_REF_TYPE Dtype>
struct GetAccEigenType;
template <>
-struct GetAccEigenType<DType_FP16>
+struct GetAccEigenType<TOSA_REF_TYPE_FP16>
{
using type = half_float::half;
};
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
struct GetAccEigenType
{
using type = typename GetEigenType<Dtype>::type;
};
// Meta function to get number of bits
-template <DType T>
+template <TOSA_REF_TYPE T>
struct GetNumBits
{
static constexpr int32_t value = 0;
};
template <>
-struct GetNumBits<DType_BOOL>
+struct GetNumBits<TOSA_REF_TYPE_BOOL>
{
static constexpr int32_t value = 1;
};
template <>
-struct GetNumBits<DType_UINT8>
+struct GetNumBits<TOSA_REF_TYPE_UINT8>
{
static constexpr int32_t value = 8;
};
template <>
-struct GetNumBits<DType_UINT16>
+struct GetNumBits<TOSA_REF_TYPE_UINT16>
{
static constexpr int32_t value = 16;
};
template <>
-struct GetNumBits<DType_INT4>
+struct GetNumBits<TOSA_REF_TYPE_INT4>
{
static constexpr int32_t value = 4;
};
template <>
-struct GetNumBits<DType_INT8>
+struct GetNumBits<TOSA_REF_TYPE_INT8>
{
static constexpr int32_t value = 8;
};
template <>
-struct GetNumBits<DType_INT16>
+struct GetNumBits<TOSA_REF_TYPE_INT16>
{
static constexpr int32_t value = 16;
};
template <>
-struct GetNumBits<DType_INT32>
+struct GetNumBits<TOSA_REF_TYPE_INT32>
{
static constexpr int32_t value = 32;
};
template <>
-struct GetNumBits<DType_INT48>
+struct GetNumBits<TOSA_REF_TYPE_INT48>
{
static constexpr int32_t value = 48;
};
template <>
-struct GetNumBits<DType_FP16>
+struct GetNumBits<TOSA_REF_TYPE_FP16>
{
static constexpr int32_t value = 16;
};
// Meta function to get quantized min/max in compile time
-template <DType T>
+template <TOSA_REF_TYPE T>
struct GetQMin
{
static constexpr int64_t value = INT64_C(0);
};
template <>
-struct GetQMin<DType_UINT8>
+struct GetQMin<TOSA_REF_TYPE_UINT8>
{
static constexpr int64_t value = INT64_C(0);
};
template <>
-struct GetQMin<DType_UINT16>
+struct GetQMin<TOSA_REF_TYPE_UINT16>
{
static constexpr int64_t value = INT64_C(0);
};
template <>
-struct GetQMin<DType_INT4>
+struct GetQMin<TOSA_REF_TYPE_INT4>
{
static constexpr int64_t value = INT64_C(-8);
};
template <>
-struct GetQMin<DType_INT8>
+struct GetQMin<TOSA_REF_TYPE_INT8>
{
static constexpr int64_t value = INT64_C(-128);
};
template <>
-struct GetQMin<DType_INT16>
+struct GetQMin<TOSA_REF_TYPE_INT16>
{
static constexpr int64_t value = INT64_C(-32768);
};
template <>
-struct GetQMin<DType_INT32>
+struct GetQMin<TOSA_REF_TYPE_INT32>
{
static constexpr int64_t value = -(INT64_C(1) << 31);
};
template <>
-struct GetQMin<DType_INT48>
+struct GetQMin<TOSA_REF_TYPE_INT48>
{
static constexpr int64_t value = -(INT64_C(1) << 47);
};
-template <DType T>
+template <TOSA_REF_TYPE T>
struct GetQMax
{
static constexpr int64_t value = INT64_C(0);
};
template <>
-struct GetQMax<DType_UINT8>
+struct GetQMax<TOSA_REF_TYPE_UINT8>
{
static constexpr int64_t value = INT64_C(255);
};
template <>
-struct GetQMax<DType_UINT16>
+struct GetQMax<TOSA_REF_TYPE_UINT16>
{
static constexpr int64_t value = INT64_C(65535);
};
template <>
-struct GetQMax<DType_INT4>
+struct GetQMax<TOSA_REF_TYPE_INT4>
{
static constexpr int64_t value = INT64_C(7);
};
template <>
-struct GetQMax<DType_INT8>
+struct GetQMax<TOSA_REF_TYPE_INT8>
{
static constexpr int64_t value = INT64_C(127);
};
template <>
-struct GetQMax<DType_INT16>
+struct GetQMax<TOSA_REF_TYPE_INT16>
{
static constexpr int64_t value = INT64_C(32767);
};
template <>
-struct GetQMax<DType_INT32>
+struct GetQMax<TOSA_REF_TYPE_INT32>
{
static constexpr int64_t value = (INT64_C(1) << 31) - 1;
};
template <>
-struct GetQMax<DType_INT48>
+struct GetQMax<TOSA_REF_TYPE_INT48>
{
static constexpr int64_t value = (INT64_C(1) << 47) - 1;
};
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index b3845df..f8fd323 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -116,14 +116,14 @@ int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute,
}
int check_conv_attribute(tosa::TosaConvAttribute* attribute,
- uint32_t conv_dimension,
- std::vector<int32_t> input_shape,
- std::vector<int32_t> output_shape,
- std::vector<int32_t> weights,
- uint32_t offset_kernel,
- DType InDtype,
- DType WeightDtype,
- std::string& msg)
+ uint32_t conv_dimension,
+ std::vector<int32_t> input_shape,
+ std::vector<int32_t> output_shape,
+ std::vector<int32_t> weights,
+ uint32_t offset_kernel,
+ TOSA_REF_TYPE InDtype,
+ TOSA_REF_TYPE WeightDtype,
+ std::string& msg)
{
if (attribute->pad().size() != (2 * conv_dimension))
{
@@ -226,11 +226,13 @@ int check_conv_attribute(tosa::TosaConvAttribute* attribute,
return 1;
}
- if (InDtype != DType_INT8 && attribute->input_zp() != 0) {
+ if (InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0)
+ {
msg = "Input zero point must be zero for non-int8 data";
return 1;
}
- if (WeightDtype != DType_INT8 && attribute->weight_zp() != 0) {
+ if (WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0)
+ {
msg = "Weight zero point must be zero for non-int8 data";
return 1;
}
@@ -318,7 +320,7 @@ int check_fft_shape(const std::vector<int32_t>& in_real,
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -330,14 +332,14 @@ OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Axis);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpArgMax<Rank, Dtype>::~OpArgMax()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpArgMax<Rank, Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -355,7 +357,7 @@ int OpArgMax<Rank, Dtype>::checkTensorAttributes()
return 1;
}
- if (outputs[0]->getDtype() != DType_INT32)
+ if (outputs[0]->getDtype() != TOSA_REF_TYPE_INT32)
{
printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator");
return 1;
@@ -400,7 +402,7 @@ int OpArgMax<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpArgMax<Rank, Dtype>::eval()
{
Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());
@@ -410,7 +412,7 @@ int OpArgMax<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -422,14 +424,14 @@ OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Pool);
}
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
OpAvgPool2d<Dtype, AccDtype>::~OpAvgPool2d()
{
if (attribute)
delete attribute;
}
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -449,8 +451,10 @@ int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
- ERROR_IF(Dtype != DType_INT8 && attribute->input_zp() != 0, "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
- ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");
+ ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
+ "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
+ ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->output_zp() != 0,
+ "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");
std::string msg;
if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
@@ -466,8 +470,9 @@ int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
// This calculates the number of padding elements used for each location along an axis
// Average pooling only divides by the number of elements used, not including padding.
// This function uses left/right, but is also used for vertical padding with top/bottom
-template <DType Dtype, DType AccDtype>
-ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
+ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(
+ int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
{
ETensor1<int32_t> result(out_size);
@@ -495,7 +500,7 @@ ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(int in_size
// assuming input and output tensor have same scales like tflite reference
// so no need to scale input and output
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
int OpAvgPool2d<Dtype, AccDtype>::eval()
{
int in_batch = this->in->getShape()[0];
@@ -531,7 +536,7 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
- tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
+ TOSA_REF_TYPE accum_dtype = ConvertDType(this->attribute->accum_dtype());
DEBUG_INFO(OP,
"perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
@@ -556,7 +561,7 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
pad[3] = std::make_pair(0, 0);
ETensor4<InEigenType> input_val = this->in->getTensor();
- if (Dtype == DType_INT8)
+ if (Dtype == TOSA_REF_TYPE_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
}
@@ -604,7 +609,8 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
dm2_h.contract(dm2_w, contract_dims)
.reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
.broadcast(bcast);
- if (Dtype != DType_FP32 && Dtype != DType_FP16 && Dtype != DType_BF16)
+ if (Dtype != TOSA_REF_TYPE_FP32 && Dtype != TOSA_REF_TYPE_FP16 && Dtype != TOSA_REF_TYPE_BF16 &&
+ Dtype != TOSA_REF_TYPE_FP64)
{
try
{
@@ -632,7 +638,7 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -644,14 +650,14 @@ OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Conv);
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpConv2d<InDtype, WeightDtype, OutDtype>::~OpConv2d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -688,7 +694,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
int in_batch = this->input->getShape()[0];
@@ -781,7 +787,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+ if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -817,7 +823,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
// reshape back to [N, H, W, C]
this->output->getTensor() = biased_output.reshape(col2im_output_dims);
- if (OutDtype == DType_INT48)
+ if (OutDtype == TOSA_REF_TYPE_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -826,7 +832,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -838,14 +844,14 @@ OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Conv);
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpConv3d<InDtype, WeightDtype, OutDtype>::~OpConv3d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -882,7 +888,7 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
{
int in_batch = this->input->getShape()[0];
@@ -959,7 +965,7 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+ if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1020,7 +1026,7 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
}
}
- if (OutDtype == DType_INT48)
+ if (OutDtype == TOSA_REF_TYPE_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1029,10 +1035,10 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
- TosaAttributeBase* attribute_,
- uint64_t id_)
+ TosaAttributeBase* attribute_,
+ uint64_t id_)
: GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
{
setRequiredOperands(3, 1);
@@ -1041,14 +1047,14 @@ OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTra
INIT_ATTRIBUTE(Conv);
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::~OpDepthwiseConv2d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1085,7 +1091,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
int in_batch = this->input->getShape()[0];
@@ -1149,7 +1155,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+ if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1205,7 +1211,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
}
}
- if (OutDtype == DType_INT48)
+ if (OutDtype == TOSA_REF_TYPE_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1214,10 +1220,10 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
- TosaAttributeBase* attribute_,
- uint64_t id_)
+ TosaAttributeBase* attribute_,
+ uint64_t id_)
: GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
{
setRequiredOperands(3, 1);
@@ -1226,14 +1232,14 @@ OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTrave
INIT_ATTRIBUTE(FullyConnected);
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpFullyConnected<InDtype, WeightDtype, OutDtype>::~OpFullyConnected()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1265,13 +1271,15 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
- ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
- ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");
+ ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
+ "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
+ ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
+ "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");
return 0;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
{
typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
@@ -1289,7 +1297,7 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
- if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+ if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1299,7 +1307,7 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims).template cast<OutEigenType>() +
this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast);
- if (OutDtype == DType_INT48)
+ if (OutDtype == TOSA_REF_TYPE_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1307,7 +1315,7 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -1319,14 +1327,14 @@ OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(MatMul);
}
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
OpMatMul<Dtype, OutDtype>::~OpMatMul()
{
if (attribute)
delete attribute;
}
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
int OpMatMul<Dtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1382,13 +1390,15 @@ int OpMatMul<Dtype, OutDtype>::checkTensorAttributes()
}
W = b->getShape()[2];
- ERROR_IF(Dtype != DType_INT8 && attribute->a_zp() != 0, "OpMatMul: A zeropoint must be zero for non int8_t data");
- ERROR_IF(Dtype != DType_INT8 && attribute->b_zp() != 0, "OpMatMul: B zeropoint must be zero for non int8_t data");
+ ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->a_zp() != 0,
+ "OpMatMul: A zeropoint must be zero for non int8_t data");
+ ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->b_zp() != 0,
+ "OpMatMul: B zeropoint must be zero for non int8_t data");
return 0;
}
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
int OpMatMul<Dtype, OutDtype>::eval()
{
typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
@@ -1396,7 +1406,7 @@ int OpMatMul<Dtype, OutDtype>::eval()
TIn a_val = this->a->getTensor();
TIn b_val = this->b->getTensor();
- if (Dtype == DType_INT8)
+ if (Dtype == TOSA_REF_TYPE_INT8)
{
a_val = a_val - (InEigenType)attribute->a_zp();
b_val = b_val - (InEigenType)attribute->b_zp();
@@ -1434,7 +1444,7 @@ int OpMatMul<Dtype, OutDtype>::eval()
}
}
- if (OutDtype == DType_INT48)
+ if (OutDtype == TOSA_REF_TYPE_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1443,7 +1453,7 @@ int OpMatMul<Dtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -1455,14 +1465,14 @@ OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Pool);
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
OpMaxPool2d<Dtype>::~OpMaxPool2d()
{
if (attribute)
delete attribute;
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpMaxPool2d<Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1493,7 +1503,7 @@ int OpMaxPool2d<Dtype>::checkTensorAttributes()
return 0;
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpMaxPool2d<Dtype>::eval()
{
int in_batch = this->in->getShape()[0];
@@ -1586,10 +1596,8 @@ int OpMaxPool2d<Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
-OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_,
- TosaAttributeBase* attribute_,
- uint64_t id_)
+template <TOSA_REF_TYPE Dtype>
+OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
: GraphNode(sgt_, Op_FFT2D, id_)
{
setRequiredOperands(2, 2);
@@ -1598,14 +1606,14 @@ OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(FFT);
}
-template <DType Dtype>
-OpFFT2d<Dtype>::~OpFFT2d() {
+template <TOSA_REF_TYPE Dtype>
+OpFFT2d<Dtype>::~OpFFT2d()
+{
if (attribute)
delete attribute;
}
-
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpFFT2d<Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1643,7 +1651,7 @@ int OpFFT2d<Dtype>::checkTensorAttributes()
return 0;
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpFFT2d<Dtype>::eval()
{
int in_real_batch = this->in_real->getShape()[0];
@@ -1709,7 +1717,7 @@ int OpFFT2d<Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -1719,11 +1727,11 @@ OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_,
setRequiredRank(3);
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
OpRFFT2d<Dtype>::~OpRFFT2d() {}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpRFFT2d<Dtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1759,7 +1767,7 @@ int OpRFFT2d<Dtype>::checkTensorAttributes()
return 0;
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpRFFT2d<Dtype>::eval()
{
int32_t in_batch = in->getShape()[0];
@@ -1815,10 +1823,10 @@ int OpRFFT2d<Dtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
- TosaAttributeBase* attribute_,
- uint64_t id_)
+ TosaAttributeBase* attribute_,
+ uint64_t id_)
: GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
{
setRequiredOperands(3, 1);
@@ -1827,14 +1835,14 @@ OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTra
INIT_ATTRIBUTE(TransposeConv);
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::~OpTransposeConv2d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
@@ -1923,13 +1931,15 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
return 1;
}
- ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
- ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");
+ ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
+ "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
+ ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
+ "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");
return 0;
}
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
int in_batch = this->input->getShape()[0];
@@ -1985,7 +1995,7 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
- if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+ if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
{
input_val = input_val - (InEigenType)attribute->input_zp();
weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -2040,7 +2050,7 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
}
}
- if (OutDtype == DType_INT48)
+ if (OutDtype == TOSA_REF_TYPE_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -2055,6 +2065,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32);
@@ -2062,6 +2073,7 @@ DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64);
// [in_t, weight_t, out_t]
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
@@ -2071,6 +2083,7 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
@@ -2079,6 +2092,7 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
@@ -2087,8 +2101,10 @@ DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32);
+DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
@@ -2097,6 +2113,7 @@ DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP64, FP64, FP64);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT16, INT48);
@@ -2104,14 +2121,17 @@ DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP16);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32);
DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32);
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP64, FP64);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP64);
DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);
+DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
@@ -2120,3 +2140,4 @@ DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index 9ef4a58..df53f2b 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -24,7 +24,7 @@ using namespace tosa;
namespace TosaReference
{
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
class OpArgMax : public GraphNode
{
public:
@@ -35,7 +35,7 @@ public:
virtual int eval();
using InEigenType = typename GetEigenType<Dtype>::type;
- using OutEigenType = typename GetEigenType<DType_INT32>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_INT32>::type;
using TIn = Eigen::Tensor<InEigenType, Rank>;
using TOut = Eigen::Tensor<OutEigenType, Rank - 1>;
@@ -45,7 +45,7 @@ protected:
TosaReference::TensorTemplate<TOut>* output;
};
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
class OpAvgPool2d : public GraphNode
{
public:
@@ -74,7 +74,7 @@ protected:
ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right);
};
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
class OpConv2d : public GraphNode
{
public:
@@ -104,7 +104,7 @@ protected:
tosa::TosaConvAttribute* attribute;
};
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
class OpConv3d : public GraphNode
{
public:
@@ -134,7 +134,7 @@ protected:
tosa::TosaConvAttribute* attribute;
};
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
class OpDepthwiseConv2d : public GraphNode
{
public:
@@ -164,7 +164,7 @@ protected:
tosa::TosaConvAttribute* attribute;
};
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
class OpFullyConnected : public GraphNode
{
public:
@@ -195,7 +195,7 @@ protected:
tosa::TosaFullyConnectedAttribute* attribute;
};
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
class OpMatMul : public GraphNode
{
public:
@@ -227,7 +227,7 @@ protected:
tosa::TosaMatMulAttribute* attribute;
};
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
class OpMaxPool2d : public GraphNode
{
public:
@@ -248,7 +248,7 @@ protected:
tosa::TosaPoolAttribute* attribute;
};
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
class OpFFT2d : public GraphNode
{
public:
@@ -271,7 +271,7 @@ protected:
tosa::TosaFFTAttribute* attribute;
};
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
class OpRFFT2d : public GraphNode
{
public:
@@ -292,7 +292,7 @@ protected:
TosaReference::TensorTemplate<TOut>* out_imag;
};
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
class OpTransposeConv2d : public GraphNode
{
public:
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 9034add..68ffb1f 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -24,7 +24,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -35,14 +35,14 @@ OpRescale<Rank, InDtype, OutDtype>::OpRescale(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Rescale);
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
OpRescale<Rank, InDtype, OutDtype>::~OpRescale()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -69,31 +69,33 @@ int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
ASSERT_MEM(in && out);
- if ((InDtype != DType_INT8) && (InDtype != DType_UINT8) && (InDtype != DType_UINT16) && (attribute->input_zp() != 0))
+ if ((InDtype != TOSA_REF_TYPE_INT8) && (InDtype != TOSA_REF_TYPE_UINT8) && (InDtype != TOSA_REF_TYPE_UINT16) &&
+ (attribute->input_zp() != 0))
{
- printNodeValidationError("OpRescale: Input DType not INT8/UINT8/UINT16 and zero point not 0");
+ printNodeValidationError("OpRescale: Input TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0");
return 1;
}
- if ((OutDtype != DType_INT8) && (OutDtype != DType_UINT8) && (OutDtype != DType_UINT16) && (attribute->output_zp() != 0))
+ if ((OutDtype != TOSA_REF_TYPE_INT8) && (OutDtype != TOSA_REF_TYPE_UINT8) && (OutDtype != TOSA_REF_TYPE_UINT16) &&
+ (attribute->output_zp() != 0))
{
- printNodeValidationError("OpRescale: Output DType not INT8/UINT8/UINT16 and zero point not 0");
+ printNodeValidationError("OpRescale: Output TOSA_REF_TYPE not INT8/UINT8/UINT16 and zero point not 0");
return 1;
}
- if ((InDtype == DType_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768)))
+ if ((InDtype == TOSA_REF_TYPE_UINT16) && ((attribute->input_zp() != 0) && (attribute->input_zp() != 32768)))
{
- printNodeValidationError("OpRescale: Input DType UINT16 and zero point not 0 or 32768");
+ printNodeValidationError("OpRescale: Input TOSA_REF_TYPE UINT16 and zero point not 0 or 32768");
return 1;
}
- if ((OutDtype == DType_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768)))
+ if ((OutDtype == TOSA_REF_TYPE_UINT16) && ((attribute->output_zp() != 0) && (attribute->output_zp() != 32768)))
{
- printNodeValidationError("OpRescale: Output DType UINT16 and zero point not 0 or 32768");
+ printNodeValidationError("OpRescale: Output TOSA_REF_TYPE UINT16 and zero point not 0 or 32768");
return 1;
}
- if (attribute->scale32() && (InDtype == DType_INT48))
+ if (attribute->scale32() && (InDtype == TOSA_REF_TYPE_INT48))
{
printNodeValidationError("OpRescale: Scale set to true but input type is INT48");
return 1;
@@ -108,7 +110,7 @@ int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int OpRescale<Rank, InDtype, OutDtype>::eval()
{
int32_t input_zp = attribute->input_zp();
@@ -237,7 +239,7 @@ int OpRescale<Rank, InDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -247,11 +249,11 @@ OpCast<Rank, InDtype, OutDtype>::OpCast(SubgraphTraverser* sgt_,
setRequiredRank(0, 6);
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
OpCast<Rank, InDtype, OutDtype>::~OpCast()
{}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -281,7 +283,7 @@ int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
int OpCast<Rank, InDtype, OutDtype>::eval()
{
this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn());
@@ -289,7 +291,7 @@ int OpCast<Rank, InDtype, OutDtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
CastHelper<InDtype, OutDtype>::CastHelper()
{
fcn = [](InEigenType in) -> OutEigenType {
@@ -298,14 +300,14 @@ CastHelper<InDtype, OutDtype>::CastHelper()
};
}
-template <DType InDtype>
-CastHelper<InDtype, DType_BOOL>::CastHelper()
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_BOOL>::CastHelper()
{
fcn = [](InEigenType in) -> bool { return (in != 0) ? true : false; };
}
-template <DType OutDtype>
-CastHelper<DType_BOOL, OutDtype>::CastHelper()
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_BOOL, OutDtype>::CastHelper()
{
fcn = [](bool in) -> OutEigenType {
OutEigenType out = in ? (OutEigenType)1 : (OutEigenType)0;
@@ -313,8 +315,8 @@ CastHelper<DType_BOOL, OutDtype>::CastHelper()
};
}
-template <DType InDtype>
-CastHelper<InDtype, DType_FP16>::CastHelper()
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_FP16>::CastHelper()
{
// Integer data converted to fp16 (stored as fp32)
fcn = [](InEigenType in) -> float {
@@ -324,17 +326,17 @@ CastHelper<InDtype, DType_FP16>::CastHelper()
};
}
-CastHelper<DType_FP32, DType_FP16>::CastHelper()
+CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP16>::CastHelper()
{
// fp32 data converted to fp16 (stored as fp32)
fcn = [](float in) -> float {
- float out = fpTrunc<DType_FP16>(in); // truncate required for conversion from higher precision
+ float out = fpTrunc<TOSA_REF_TYPE_FP16>(in); // truncate required for conversion from higher precision
return out;
};
}
-template <DType InDtype>
-CastHelper<InDtype, DType_BF16>::CastHelper()
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_BF16>::CastHelper()
{
// Integer data converted to bf16 (stored as fp32)
fcn = [](InEigenType in) -> float {
@@ -343,16 +345,16 @@ CastHelper<InDtype, DType_BF16>::CastHelper()
};
}
-CastHelper<DType_FP32, DType_BF16>::CastHelper()
+CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_BF16>::CastHelper()
{
// fp32 data converted to bf16 (stored as fp32)
fcn = [](float in) -> float {
- return fpTrunc<DType_BF16>(in); // truncate required for conversions from higher precision
+ return fpTrunc<TOSA_REF_TYPE_BF16>(in); // truncate required for conversions from higher precision
};
}
-template <DType OutDtype>
-CastHelper<DType_FP16, OutDtype>::CastHelper()
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_FP16, OutDtype>::CastHelper()
{
// fp16 data (stored as fp32) converted to integer
fcn = [](float in) -> OutEigenType {
@@ -366,7 +368,7 @@ CastHelper<DType_FP16, OutDtype>::CastHelper()
};
}
-CastHelper<DType_FP16, DType_FP32>::CastHelper()
+CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP32>::CastHelper()
{
// No-op since fp16 values treated internally as their fp32 representation
fcn = [](float in) -> OutEigenType {
@@ -374,8 +376,8 @@ CastHelper<DType_FP16, DType_FP32>::CastHelper()
};
}
-template <DType OutDtype>
-CastHelper<DType_BF16, OutDtype>::CastHelper()
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_BF16, OutDtype>::CastHelper()
{
// bf16 data (stored as fp32) converted to integer
fcn = [](float in) -> OutEigenType {
@@ -386,7 +388,7 @@ CastHelper<DType_BF16, OutDtype>::CastHelper()
};
}
-CastHelper<DType_BF16, DType_FP32>::CastHelper()
+CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP32>::CastHelper()
{
// No-op since bf16 values treated as truncated fp32 internally
fcn = [](InEigenType in) -> OutEigenType {
@@ -394,8 +396,8 @@ CastHelper<DType_BF16, DType_FP32>::CastHelper()
};
}
-template <DType InDtype>
-CastHelper<InDtype, DType_FP32>::CastHelper()
+template <TOSA_REF_TYPE InDtype>
+CastHelper<InDtype, TOSA_REF_TYPE_FP32>::CastHelper()
{
// Integer data converted to fp32
fcn = [](InEigenType in) -> float {
@@ -404,8 +406,8 @@ CastHelper<InDtype, DType_FP32>::CastHelper()
};
}
-template <DType OutDtype>
-CastHelper<DType_FP32, OutDtype>::CastHelper()
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_FP32, OutDtype>::CastHelper()
{
// fp32 data converted to integer
fcn = [](float in) -> OutEigenType {
@@ -416,6 +418,31 @@ CastHelper<DType_FP32, OutDtype>::CastHelper()
};
}
+template <TOSA_REF_TYPE OutDtype>
+CastHelper<TOSA_REF_TYPE_FP64, OutDtype>::CastHelper()
+{
+ switch (OutDtype)
+ {
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_INT32:
+ // fp64 data converted to integer
+ fcn = [](InEigenType in) -> OutEigenType {
+ OutEigenType out = std::rint(in);
+ out = std::max<OutEigenType>(out, OutMin);
+ out = std::min<OutEigenType>(out, OutMax);
+ return out;
+ };
+ break;
+ case TOSA_REF_TYPE_FP64:
+ // no op
+ fcn = [](InEigenType in) -> OutEigenType { return in; };
+ break;
+ default:
+ ASSERT_MSG(false, "unsupported TOSA_REF_TYPE %s", EnumNameTOSAREFTYPE(OutDtype));
+ }
+}
+
// template explicit instantiation
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16);
@@ -451,6 +478,13 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, FP16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, BF16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP64, FP64);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP64);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP64);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP64);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16);
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h
index e2fc6e2..98799a0 100644
--- a/reference_model/src/ops/type_conversion.h
+++ b/reference_model/src/ops/type_conversion.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020-2022, ARM Limited.
+// Copyright (c) 2020-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ using namespace tosa;
namespace TosaReference
{
-template <int Rank, DType InDtype, DType OutDtype>
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
class OpRescale : public GraphNode
{
public:
@@ -46,7 +46,7 @@ protected:
TosaReference::TensorTemplate<TOut>* out;
};
-template <DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
class CastHelper
{
public:
@@ -64,12 +64,12 @@ private:
FcnType fcn;
};
-template <DType InDtype>
-class CastHelper<InDtype, DType_BOOL>
+template <TOSA_REF_TYPE InDtype>
+class CastHelper<InDtype, TOSA_REF_TYPE_BOOL>
{
public:
using InEigenType = typename GetEigenType<InDtype>::type;
- using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
CastHelper();
const FcnType& get_fcn() const
@@ -81,11 +81,11 @@ private:
FcnType fcn;
};
-template <DType OutDtype>
-class CastHelper<DType_BOOL, OutDtype>
+template <TOSA_REF_TYPE OutDtype>
+class CastHelper<TOSA_REF_TYPE_BOOL, OutDtype>
{
public:
- using InEigenType = typename GetEigenType<DType_BOOL>::type;
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_BOOL>::type;
using OutEigenType = typename GetEigenType<OutDtype>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
@@ -100,12 +100,12 @@ private:
FcnType fcn;
};
-template <DType InDtype>
-class CastHelper<InDtype, DType_FP16>
+template <TOSA_REF_TYPE InDtype>
+class CastHelper<InDtype, TOSA_REF_TYPE_FP16>
{
public:
using InEigenType = typename GetEigenType<InDtype>::type;
- using OutEigenType = typename GetEigenType<DType_FP16>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
CastHelper();
const FcnType& get_fcn() const
@@ -117,11 +117,11 @@ private:
FcnType fcn;
};
-template <DType OutDtype>
-class CastHelper<DType_FP16, OutDtype>
+template <TOSA_REF_TYPE OutDtype>
+class CastHelper<TOSA_REF_TYPE_FP16, OutDtype>
{
public:
- using InEigenType = typename GetEigenType<DType_FP16>::type;
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type;
using OutEigenType = typename GetEigenType<OutDtype>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
@@ -137,11 +137,11 @@ private:
};
template <>
-class CastHelper<DType_FP32, DType_FP16>
+class CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_FP16>
{
public:
- using InEigenType = typename GetEigenType<DType_FP32>::type;
- using OutEigenType = typename GetEigenType<DType_FP16>::type;
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
CastHelper();
const FcnType& get_fcn() const
@@ -153,12 +153,12 @@ private:
FcnType fcn;
};
-template <DType InDtype>
-class CastHelper<InDtype, DType_BF16>
+template <TOSA_REF_TYPE InDtype>
+class CastHelper<InDtype, TOSA_REF_TYPE_BF16>
{
public:
using InEigenType = typename GetEigenType<InDtype>::type;
- using OutEigenType = typename GetEigenType<DType_BF16>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
CastHelper();
const FcnType& get_fcn() const
@@ -170,11 +170,11 @@ private:
FcnType fcn;
};
-template <DType OutDtype>
-class CastHelper<DType_BF16, OutDtype>
+template <TOSA_REF_TYPE OutDtype>
+class CastHelper<TOSA_REF_TYPE_BF16, OutDtype>
{
public:
- using InEigenType = typename GetEigenType<DType_BF16>::type;
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type;
using OutEigenType = typename GetEigenType<OutDtype>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
@@ -190,11 +190,11 @@ private:
};
template <>
-class CastHelper<DType_FP32, DType_BF16>
+class CastHelper<TOSA_REF_TYPE_FP32, TOSA_REF_TYPE_BF16>
{
public:
- using InEigenType = typename GetEigenType<DType_FP32>::type;
- using OutEigenType = typename GetEigenType<DType_BF16>::type;
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
CastHelper();
const FcnType& get_fcn() const
@@ -206,12 +206,12 @@ private:
FcnType fcn;
};
-template <DType InDtype>
-class CastHelper<InDtype, DType_FP32>
+template <TOSA_REF_TYPE InDtype>
+class CastHelper<InDtype, TOSA_REF_TYPE_FP32>
{
public:
using InEigenType = typename GetEigenType<InDtype>::type;
- using OutEigenType = typename GetEigenType<DType_FP32>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
CastHelper();
const FcnType& get_fcn() const
@@ -224,11 +224,11 @@ private:
};
template <>
-class CastHelper<DType_FP16, DType_FP32>
+class CastHelper<TOSA_REF_TYPE_FP16, TOSA_REF_TYPE_FP32>
{
public:
- using InEigenType = typename GetEigenType<DType_FP16>::type;
- using OutEigenType = typename GetEigenType<DType_FP32>::type;
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP16>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
CastHelper();
const FcnType& get_fcn() const
@@ -241,11 +241,11 @@ private:
};
template <>
-class CastHelper<DType_BF16, DType_FP32>
+class CastHelper<TOSA_REF_TYPE_BF16, TOSA_REF_TYPE_FP32>
{
public:
- using InEigenType = typename GetEigenType<DType_BF16>::type;
- using OutEigenType = typename GetEigenType<DType_FP32>::type;
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_BF16>::type;
+ using OutEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
CastHelper();
const FcnType& get_fcn() const
@@ -257,11 +257,11 @@ private:
FcnType fcn;
};
-template <DType OutDtype>
-class CastHelper<DType_FP32, OutDtype>
+template <TOSA_REF_TYPE OutDtype>
+class CastHelper<TOSA_REF_TYPE_FP32, OutDtype>
{
public:
- using InEigenType = typename GetEigenType<DType_FP32>::type;
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP32>::type;
using OutEigenType = typename GetEigenType<OutDtype>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
@@ -276,7 +276,26 @@ private:
FcnType fcn;
};
-template <int Rank, DType InDtype, DType OutDtype>
+template <TOSA_REF_TYPE OutDtype>
+class CastHelper<TOSA_REF_TYPE_FP64, OutDtype>
+{
+public:
+ using InEigenType = typename GetEigenType<TOSA_REF_TYPE_FP64>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
+ static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <int Rank, TOSA_REF_TYPE InDtype, TOSA_REF_TYPE OutDtype>
class OpCast : public GraphNode
{
public: