aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2022-09-27 13:50:00 +0100
committerJeremy Johnson <jeremy.johnson@arm.com>2022-10-13 18:21:15 +0100
commitbc2a3db54ecee48fe2236f7fc03da8fd07d81ca0 (patch)
treec3908f23c369fd3226e840f81c3ba4b49cc409a0
parent93d4390f9aa5c4369f889e1cd336aa4e809ff6a7 (diff)
downloadreference_model-bc2a3db54ecee48fe2236f7fc03da8fd07d81ca0.tar.gz
Rename FLOAT type to FP32
Update tensor operations naming to state input type as TxT in all cases. Effects CONV2D, CONV3D, DEPTHWISE_CONV2D, FULLY_CONNECTED, TRANSPOSE_CONV2D. Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com> Change-Id: Ic959acfcb3aa0a910b33b774a5a85fac08219205
-rw-r--r--reference_model/src/ops/activation_funcs.cc12
-rw-r--r--reference_model/src/ops/comparison.cc12
-rw-r--r--reference_model/src/ops/data_layout.cc16
-rw-r--r--reference_model/src/ops/data_nodes.cc2
-rw-r--r--reference_model/src/ops/ewise_binary.cc26
-rw-r--r--reference_model/src/ops/ewise_binary.h4
-rw-r--r--reference_model/src/ops/ewise_ternary.cc2
-rw-r--r--reference_model/src/ops/ewise_unary.cc32
-rw-r--r--reference_model/src/ops/image.cc6
-rw-r--r--reference_model/src/ops/op_factory.cc116
-rw-r--r--reference_model/src/ops/op_factory.h2
-rw-r--r--reference_model/src/ops/reduction.cc8
-rw-r--r--reference_model/src/ops/scatter_gather.cc4
-rw-r--r--reference_model/src/ops/template_types.h6
-rw-r--r--reference_model/src/ops/tensor_ops.cc34
-rw-r--r--reference_model/src/ops/type_conversion.cc16
-rw-r--r--reference_model/src/ops/type_conversion.h8
-rw-r--r--reference_model/src/subgraph_traverser.cc4
-rw-r--r--reference_model/src/tensor.cc12
-rw-r--r--reference_model/src/tensor.h2
m---------thirdparty/serialization_lib0
-rw-r--r--verif/generator/tosa_arg_gen.py38
-rw-r--r--verif/generator/tosa_error_if.py56
-rw-r--r--verif/generator/tosa_test_gen.py148
-rw-r--r--verif/generator/tosa_utils.py25
-rw-r--r--verif/tests/test_tosa_refmodel.py10
26 files changed, 296 insertions, 305 deletions
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
index 1c0c23a..61f7df6 100644
--- a/reference_model/src/ops/activation_funcs.cc
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -28,7 +28,7 @@ int OpClamp<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
- case DType_FLOAT:
+ case DType_FP32:
{
InEigenType min = (InEigenType)attribute->min_fp();
InEigenType max = (InEigenType)attribute->max_fp();
@@ -59,7 +59,7 @@ int OpSigmoid<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
- case DType_FLOAT:
+ case DType_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return (1.0 / (1.0 + (expf(-1.0 * a)))); };
break;
default:
@@ -75,7 +75,7 @@ int OpTanh<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
- case DType_FLOAT:
+ case DType_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return tanhf(a); };
break;
default:
@@ -87,12 +87,12 @@ int OpTanh<Rank, Dtype>::register_fcn()
// template explicit instantiation
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32);
diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc
index 5930c1a..f240aa5 100644
--- a/reference_model/src/ops/comparison.cc
+++ b/reference_model/src/ops/comparison.cc
@@ -28,7 +28,7 @@ int OpEqual<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
- case DType_FLOAT:
+ case DType_FP32:
case DType_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; };
break;
@@ -45,7 +45,7 @@ int OpGreater<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
- case DType_FLOAT:
+ case DType_FP32:
case DType_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; };
break;
@@ -62,7 +62,7 @@ int OpGreaterEqual<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
- case DType_FLOAT:
+ case DType_FP32:
case DType_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; };
break;
@@ -75,13 +75,13 @@ int OpGreaterEqual<Rank, Dtype>::register_fcn()
// template explicit instantiation
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index 1ed0be2..69b6a65 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -191,7 +191,7 @@ int OpPad<Rank, Dtype>::eval()
pad_value = (InEigenType)attribute->pad_const_int();
break;
case DType_FP16:
- case DType_FLOAT:
+ case DType_FP32:
pad_value = (InEigenType)attribute->pad_const_fp();
break;
default:
@@ -639,49 +639,49 @@ int OpTranspose<Rank, Dtype>::eval()
// template explicit instantiation
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16)
-DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP32)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
-DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
DEF_INSTANTIATE_RESHAPE(OpReshape, FP16);
-DEF_INSTANTIATE_RESHAPE(OpReshape, FLOAT);
+DEF_INSTANTIATE_RESHAPE(OpReshape, FP32);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT8);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT32);
DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
-DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc
index 4ff08be..5709a92 100644
--- a/reference_model/src/ops/data_nodes.cc
+++ b/reference_model/src/ops/data_nodes.cc
@@ -90,7 +90,7 @@ int OpIdentity<Rank, Dtype>::eval()
// note OpConst is not templated
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
index 917d56e..098b0ea 100644
--- a/reference_model/src/ops/ewise_binary.cc
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -143,7 +143,7 @@ int OpAdd<Rank, Dtype>::register_fcn()
};
break;
case DType_FP16:
- case DType_FLOAT:
+ case DType_FP32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
break;
default:
@@ -371,7 +371,7 @@ int OpMaximum<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
- case DType_FLOAT:
+ case DType_FP32:
case DType_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
break;
@@ -388,7 +388,7 @@ int OpMinimum<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
- case DType_FLOAT:
+ case DType_FP32:
case DType_INT32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
break;
@@ -407,7 +407,7 @@ int OpMul<Rank, InDtype, OutDtype>::register_fcn()
switch (InDtype)
{
case DType_FP16:
- case DType_FLOAT:
+ case DType_FP32:
this->fcn = [shift](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
break;
case DType_INT32:
@@ -457,7 +457,7 @@ int OpPow<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
- case DType_FLOAT:
+ case DType_FP32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); };
break;
default:
@@ -482,7 +482,7 @@ int OpSub<Rank, Dtype>::register_fcn()
};
break;
case DType_FP16:
- case DType_FLOAT:
+ case DType_FP32:
this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
break;
default:
@@ -581,7 +581,7 @@ int OpTable<Rank, InDtype>::eval()
// template explicit instantiation
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
@@ -617,24 +617,24 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT8);
@@ -643,5 +643,5 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTable, INT16);
// Instantiation of nodes for comparison operators opEqual, opGreater
// and opGreaterEqual
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP16, BOOL);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FP32, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
index b2c92a4..6b0efaf 100644
--- a/reference_model/src/ops/ewise_binary.h
+++ b/reference_model/src/ops/ewise_binary.h
@@ -1,5 +1,5 @@
-// Copyright (c) 2020, ARM Limited.
+// Copyright (c) 2020-2022, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -36,7 +36,7 @@ namespace TosaReference
// Eigen::Tensor does support some binary element-wise natively (e.g. CWiseMax, or '+', etc.)
// which might be faster since it could be implemented with SIMD instructions
// the way of registering lambda + .binaryExpr() might sacrifice performance here
-// but it can avoid partially specialization for combination of {rankN, rank0} x {FLOAT/INT32, QU8, ...}
+// but it can avoid partially specialization for combination of {rankN, rank0} x {FP32/INT32, QU8, ...}
// needs to revisit if performance becomes a bottleneck here
template <int Rank, DType InDtype, DType OutDtype>
class BinaryNodeBase : public GraphNode
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc
index da046a7..d85da1a 100644
--- a/reference_model/src/ops/ewise_ternary.cc
+++ b/reference_model/src/ops/ewise_ternary.cc
@@ -108,7 +108,7 @@ int OpSelect<0, Dtype>::eval()
// template explicit instantiation
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32);
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
index 52f5aff..00897cc 100644
--- a/reference_model/src/ops/ewise_unary.cc
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -78,7 +78,7 @@ int OpAbs<Rank, Dtype>::register_fcn()
{
switch (Dtype)
{
- case DType_FLOAT:
+ case DType_FP32:
case DType_FP16:
case DType_INT32:
this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); };
@@ -113,7 +113,7 @@ int OpCeil<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
- case DType_FLOAT:
+ case DType_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); };
break;
default:
@@ -161,7 +161,7 @@ int OpExp<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
- case DType_FLOAT:
+ case DType_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); };
break;
default:
@@ -177,7 +177,7 @@ int OpFloor<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
- case DType_FLOAT:
+ case DType_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); };
break;
default:
@@ -193,7 +193,7 @@ int OpLog<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
- case DType_FLOAT:
+ case DType_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); };
break;
default:
@@ -245,7 +245,7 @@ int OpNegate<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
- case DType_FLOAT:
+ case DType_FP32:
this->fcn = [](InEigenType a) -> OutEigenType {
InEigenType result = -(a);
return result;
@@ -297,7 +297,7 @@ int OpReciprocal<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
- case DType_FLOAT:
+ case DType_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; };
break;
default:
@@ -313,7 +313,7 @@ int OpRsqrt<Rank, Dtype>::register_fcn()
switch (Dtype)
{
case DType_FP16:
- case DType_FLOAT:
+ case DType_FP32:
this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); };
break;
default:
@@ -325,7 +325,7 @@ int OpRsqrt<Rank, Dtype>::register_fcn()
// template explicit instantiation
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT8);
@@ -333,29 +333,29 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32);
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index 891261b..cf1d9f7 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -63,7 +63,7 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
if (this->mode == ResizeMode_BILINEAR)
{
- if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FLOAT && OutDtype != DType_FP16)
+ if (OutDtype != DType_INT32 && OutDtype != DType_INT48 && OutDtype != DType_FP32 && OutDtype != DType_FP16)
{
printNodeValidationError("OpResize: invalid data type for BILINEAR");
return 1;
@@ -71,7 +71,7 @@ int OpResize<InDtype, OutDtype, resize_t>::checkTensorAttributes()
}
else
{
- if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FLOAT && OutDtype != DType_FP16)
+ if (OutDtype != DType_INT8 && OutDtype != DType_INT16 && OutDtype != DType_FP32 && OutDtype != DType_FP16)
{
printNodeValidationError("OpResize: invalid data type for NEAREST");
return 1;
@@ -225,4 +225,4 @@ DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT8, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT48, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT16, int16_t);
DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, float);
-DEF_INSTANTIATE_THREE_TYPE(OpResize, FLOAT, FLOAT, float);
+DEF_INSTANTIATE_THREE_TYPE(OpResize, FP32, FP32, float);
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index fd73eb5..1ff8229 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -49,66 +49,66 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
// tensor_ops
case Op_ARGMAX:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
- DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
break;
case Op_AVG_POOL2D:
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP16);
- DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FLOAT);
- DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FLOAT, FLOAT);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP16, FP32);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP32, FP32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT8, INT32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32);
break;
case Op_CONV2D:
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FLOAT);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FLOAT, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP32, FP32, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT16, INT8, INT48);
break;
case Op_CONV3D:
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FLOAT);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FLOAT, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP32, FP32, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT16, INT8, INT48);
break;
case Op_DEPTHWISE_CONV2D:
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FLOAT);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FLOAT, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP32, FP32, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT16, INT8, INT48);
break;
case Op_FULLY_CONNECTED:
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP16);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FLOAT);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FLOAT, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP32, FP32, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT4, INT32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT8, INT32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT16, INT8, INT48);
break;
case Op_MATMUL:
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP16);
- DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FLOAT);
- DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FLOAT, FLOAT);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP32);
+ DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP32, FP32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT8, INT32);
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT16, INT48);
break;
case Op_MAX_POOL2D:
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16);
- DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FLOAT);
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP32);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT8);
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16);
break;
case Op_TRANSPOSE_CONV2D:
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FLOAT);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FLOAT, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32);
+ DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP32, FP32, FP32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32);
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT16, INT8, INT48);
@@ -117,23 +117,23 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
// activation_funcs
case Op_CLAMP:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
break;
case Op_SIGMOID:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FP32);
break;
case Op_TANH:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FP32);
break;
// ewise_binary
case Op_ADD:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
break;
case Op_ARITHMETIC_RIGHT_SHIFT:
@@ -180,28 +180,28 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_MAXIMUM:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
break;
case Op_MINIMUM:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
break;
case Op_MUL:
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP16, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FP32, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
break;
case Op_POW:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FP32);
break;
case Op_SUB:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
break;
case Op_TABLE:
@@ -212,7 +212,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
// ewise_unary
case Op_ABS:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
break;
case Op_BITWISE_NOT:
@@ -222,46 +222,46 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_CEIL:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FP32);
break;
case Op_CLZ:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
break;
case Op_EXP:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FP32);
break;
case Op_FLOOR:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FP32);
break;
case Op_LOG:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FP32);
break;
case Op_LOGICAL_NOT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
break;
case Op_NEGATE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
break;
case Op_RECIPROCAL:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FP32);
break;
case Op_RSQRT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FP32);
break;
// ewise_ternary
case Op_SELECT:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32);
@@ -271,17 +271,17 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
// comparison
case Op_EQUAL:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
break;
case Op_GREATER:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
break;
case Op_GREATER_EQUAL:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
break;
@@ -294,32 +294,32 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_REDUCE_MAX:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16);
- DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
break;
case Op_REDUCE_MIN:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
- DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
break;
case Op_REDUCE_PRODUCT:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
- DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32);
break;
case Op_REDUCE_SUM:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
- DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);
break;
// data layout
case Op_CONCAT:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP16);
- DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32);
@@ -327,7 +327,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_PAD:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
- DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
@@ -335,7 +335,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_RESHAPE:
DEF_FACTORY_RESHAPE(OpReshape, FP16);
- DEF_FACTORY_RESHAPE(OpReshape, FLOAT);
+ DEF_FACTORY_RESHAPE(OpReshape, FP32);
DEF_FACTORY_RESHAPE(OpReshape, INT8);
DEF_FACTORY_RESHAPE(OpReshape, INT16);
DEF_FACTORY_RESHAPE(OpReshape, INT32);
@@ -343,7 +343,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_REVERSE:
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
- DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP32);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
@@ -351,7 +351,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_SLICE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
@@ -359,7 +359,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
break;
case Op_TILE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
@@ -368,7 +368,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
case Op_TRANSPOSE:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
@@ -380,14 +380,14 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpGather, INT16);
DEF_FACTORY_ONE_TYPE(OpGather, INT32);
DEF_FACTORY_ONE_TYPE(OpGather, FP16);
- DEF_FACTORY_ONE_TYPE(OpGather, FLOAT);
+ DEF_FACTORY_ONE_TYPE(OpGather, FP32);
break;
case Op_SCATTER:
DEF_FACTORY_ONE_TYPE(OpScatter, INT8);
DEF_FACTORY_ONE_TYPE(OpScatter, INT16);
DEF_FACTORY_ONE_TYPE(OpScatter, INT32);
DEF_FACTORY_ONE_TYPE(OpScatter, FP16);
- DEF_FACTORY_ONE_TYPE(OpScatter, FLOAT);
+ DEF_FACTORY_ONE_TYPE(OpScatter, FP32);
break;
// image
@@ -397,7 +397,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT48);
DEF_FACTORY_TWO_TYPE_RESIZE_INT16(OpResize, INT16, INT16);
DEF_FACTORY_TWO_TYPE_RESIZE_FP16(OpResize, FP16, FP16);
- DEF_FACTORY_TWO_TYPE_RESIZE_FLOAT(OpResize, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OpResize, FP32, FP32);
break;
// data_nodes
@@ -405,7 +405,7 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
return new OpConst(sgt, id);
case Op_IDENTITY:
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
@@ -421,23 +421,23 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16);
- DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP32);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
- DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8);
- DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16);
- DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
break;
case Op_RESCALE:
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
index 25dfc6e..b525e69 100644
--- a/reference_model/src/ops/op_factory.h
+++ b/reference_model/src/ops/op_factory.h
@@ -111,7 +111,7 @@
return new OP<DType_##DTYPE1, DType_##DTYPE2, float>(sgt, attribute, id); \
}
-#define DEF_FACTORY_TWO_TYPE_RESIZE_FLOAT(OP, DTYPE1, DTYPE2) \
+#define DEF_FACTORY_TWO_TYPE_RESIZE_FP32(OP, DTYPE1, DTYPE2) \
if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
{ \
return new OP<DType_##DTYPE1, DType_##DTYPE2, float>(sgt, attribute, id); \
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
index 03ee660..eccba09 100644
--- a/reference_model/src/ops/reduction.cc
+++ b/reference_model/src/ops/reduction.cc
@@ -159,20 +159,20 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP16);
-DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP16);
-DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP16);
-DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP16);
-DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSumInt, INT32);
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc
index 25174bd..b6c4043 100644
--- a/reference_model/src/ops/scatter_gather.cc
+++ b/reference_model/src/ops/scatter_gather.cc
@@ -227,10 +227,10 @@ DEF_INSTANTIATE_ONE_TYPE(OpGather, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpGather, INT16);
DEF_INSTANTIATE_ONE_TYPE(OpGather, INT32);
DEF_INSTANTIATE_ONE_TYPE(OpGather, FP16);
-DEF_INSTANTIATE_ONE_TYPE(OpGather, FLOAT);
+DEF_INSTANTIATE_ONE_TYPE(OpGather, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT16);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, INT32);
DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP16);
-DEF_INSTANTIATE_ONE_TYPE(OpScatter, FLOAT);
+DEF_INSTANTIATE_ONE_TYPE(OpScatter, FP32);
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h
index 9511c31..3de4899 100644
--- a/reference_model/src/ops/template_types.h
+++ b/reference_model/src/ops/template_types.h
@@ -65,7 +65,7 @@ using Tensor6 = TensorTemplate<ETensor6<T>>;
template <DType type>
struct GetEigenType;
template <>
-struct GetEigenType<DType_FLOAT>
+struct GetEigenType<DType_FP32>
{
using type = float;
};
@@ -301,9 +301,9 @@ struct GetAccDType<DType_FP16, DType_FP16>
static constexpr DType value = DType_FP16;
};
template <>
-struct GetAccDType<DType_FLOAT, DType_FLOAT>
+struct GetAccDType<DType_FP32, DType_FP32>
{
- static constexpr DType value = DType_FLOAT;
+ static constexpr DType value = DType_FP32;
};
}; // namespace TosaReference
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index c617dda..7db5182 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -512,7 +512,7 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
.contract(div_map_w.reshape(Eigen::array<Eigen::Index, 2>{ 1, out_width }), contract_dims)
.reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
.broadcast(bcast);
- if (Dtype != DType_FLOAT && Dtype != DType_FP16)
+ if (Dtype != DType_FP32 && Dtype != DType_FP16)
{
try
{
@@ -1679,41 +1679,41 @@ int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::eval()
// template explicit instantiation
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP16);
-DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP16);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FLOAT);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FLOAT, FLOAT);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP32);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP32, FP32);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT8, INT32);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT16, INT32);
// [in_t, weight_t, acc_t]
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP16);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FLOAT);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FLOAT, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT16, INT8, INT48);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP16);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FLOAT);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FLOAT, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP32, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT4, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT16, INT8, INT48);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP16);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FLOAT);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FLOAT, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT16, INT8, INT48);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP16);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FLOAT);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FLOAT, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP32, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT4, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT16, INT8, INT48);
@@ -1721,17 +1721,17 @@ DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT16, INT8, INT48);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT8, INT32);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT16, INT48);
DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP16);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FLOAT);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FLOAT, FLOAT);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP32);
+DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP32, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
-DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FLOAT);
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP16);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FLOAT);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FLOAT, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP32);
+DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP32, FP32, FP32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT4, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT8, INT32);
DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT16, INT8, INT48);
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
index 50e710a..f51c38c 100644
--- a/reference_model/src/ops/type_conversion.cc
+++ b/reference_model/src/ops/type_conversion.cc
@@ -312,7 +312,7 @@ CastHelper<DType_FP16, OutDtype>::CastHelper()
}
template <DType InDtype>
-CastHelper<InDtype, DType_FLOAT>::CastHelper()
+CastHelper<InDtype, DType_FP32>::CastHelper()
{
fcn = [](InEigenType in) -> float {
float out = (OutEigenType)in; // default cast to float is round_to_nearest_float()
@@ -321,7 +321,7 @@ CastHelper<InDtype, DType_FLOAT>::CastHelper()
}
template <DType OutDtype>
-CastHelper<DType_FLOAT, OutDtype>::CastHelper()
+CastHelper<DType_FP32, OutDtype>::CastHelper()
{
fcn = [](float in) -> OutEigenType {
OutEigenType out = std::round(in);
@@ -339,23 +339,23 @@ DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FP32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT16);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP16, INT32);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16);
-DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FP32, INT32);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT8);
DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT8, INT16);
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h
index 5f197cf..b0de30c 100644
--- a/reference_model/src/ops/type_conversion.h
+++ b/reference_model/src/ops/type_conversion.h
@@ -137,11 +137,11 @@ private:
};
template <DType InDtype>
-class CastHelper<InDtype, DType_FLOAT>
+class CastHelper<InDtype, DType_FP32>
{
public:
using InEigenType = typename GetEigenType<InDtype>::type;
- using OutEigenType = typename GetEigenType<DType_FLOAT>::type;
+ using OutEigenType = typename GetEigenType<DType_FP32>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
CastHelper();
const FcnType& get_fcn() const
@@ -154,10 +154,10 @@ private:
};
template <DType OutDtype>
-class CastHelper<DType_FLOAT, OutDtype>
+class CastHelper<DType_FP32, OutDtype>
{
public:
- using InEigenType = typename GetEigenType<DType_FLOAT>::type;
+ using InEigenType = typename GetEigenType<DType_FP32>::type;
using OutEigenType = typename GetEigenType<OutDtype>::type;
using FcnType = std::function<OutEigenType(InEigenType)>;
static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
index cbb7001..15d82e6 100644
--- a/reference_model/src/subgraph_traverser.cc
+++ b/reference_model/src/subgraph_traverser.cc
@@ -394,7 +394,7 @@ int SubgraphTraverser::allocateTensor()
tensor->setTensorValueFloat(f16_data.size(), f16_data.data());
}
break;
- case DType_FLOAT:
+ case DType_FP32:
{
std::vector<float> fp32_data;
TosaSerializationHandler::ConvertU8toF32(ts->GetData(), tensor->getElementCount(), fp32_data);
@@ -702,7 +702,7 @@ int SubgraphTraverser::validateGraph()
DType dtype = currTensor->getDtype();
// Float-point disallowed
- if (dtype == DType_FLOAT || dtype == DType_FP16)
+ if (dtype == DType_FP32 || dtype == DType_FP16)
{
WARNING("SubgraphTraverser::validateGraph(): TOSA Base Inference profile selected: All floating point "
"disabled, but %s tensor %s found\n",
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index cbe12a9..8e65a27 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -93,7 +93,7 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
switch (getDtype())
{
- case DType_FLOAT:
+ case DType_FP32:
fdatabuf = (float*)calloc(sizeof(float), elements);
ASSERT_MEM(fdatabuf);
@@ -161,8 +161,8 @@ int TosaReference::Tensor::readFromNpyFile(const char* filename)
for (uint32_t i=0; i < elements; i++) {
fdatabuf[i] = half_float::half_cast<float, half_float::half>(f16databuf[i]);
}
- // Fall through to DType_FLOAT case
- case DType_FLOAT:
+ // Fall through to DType_FP32 case
+ case DType_FP32:
if (setTensorValueFloat(elements, fdatabuf))
{
if (f16databuf)
@@ -229,7 +229,7 @@ int TosaReference::Tensor::writeToNpyFile(const char* filename) const
switch (getDtype())
{
- case DType_FLOAT:
+ case DType_FP32:
fdatabuf = (float*)calloc(sizeof(float), elements);
ASSERT_MEM(fdatabuf);
@@ -429,7 +429,7 @@ int TosaReference::Tensor::readfromVector(const std::vector<float>& vals)
uint32_t elements = getElementCount();
switch (getDtype())
{
- case DType_FLOAT:
+ case DType_FP32:
if (vals.size() != elements)
{
WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
@@ -532,7 +532,7 @@ int TosaReference::Tensor::writeToVector(std::vector<float>& vals)
switch (getDtype())
{
- case DType_FLOAT:
+ case DType_FP32:
if (vals.size() != elements)
{
WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index 78a210e..efd7e62 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -642,7 +642,7 @@ public:
{
switch (tensorDtype_)
{
- case DType_FLOAT:
+ case DType_FP32:
case DType_FP16:
switch (rank)
{
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
-Subproject 485a11d8cb67c8062c632f0987cd31cedbe93d6
+Subproject e1072a9ed871fd474e7b09b7a74ae7be5f0a6f7
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index e0c6cf0..791fbf7 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -776,7 +776,7 @@ class TosaTensorValuesGen:
), "Op.MUL must have 2 placeholders, 0 consts"
tens = []
- if dtypeList[0] in (DType.FP16, DType.FLOAT):
+ if dtypeList[0] in (DType.FP16, DType.FP32):
tens.extend(testGen.buildPlaceholderTensors(shapeList[:], dtypeList[:]))
else:
placeholders = []
@@ -1106,10 +1106,10 @@ class TosaArgGen:
@staticmethod
def agFullyConnected(testGen, opName, shapeList, dtypes, error_name=None):
- if isinstance(dtypes, list) or isinstance(dtypes, tuple):
- input_dtype = dtypes[0]
- else:
- input_dtype = dtypes
+ assert isinstance(dtypes, list) or isinstance(
+ dtypes, tuple
+ ), f"{dtypes} unexpected"
+ input_dtype = dtypes[0]
if error_name == ErrorIf.WrongOutputType:
accum_dtype = get_wrong_output_type(opName, testGen.rng, input_dtype)
@@ -1129,9 +1129,9 @@ class TosaArgGen:
elif dtype == DType.INT16:
accum_dtypes = [DType.INT48]
elif dtype == DType.FP16:
- accum_dtypes = [DType.FP16, DType.FLOAT]
- elif dtype == DType.FLOAT:
- accum_dtypes = [DType.FLOAT]
+ accum_dtypes = [DType.FP16, DType.FP32]
+ elif dtype == DType.FP32:
+ accum_dtypes = [DType.FP32]
elif error_name is None:
assert False, f"Invalid I/O DType for MatMul: {DTypeNames[dtype]}"
@@ -1245,7 +1245,7 @@ class TosaArgGen:
if dtype in [DType.BOOL, DType.INT8, DType.INT16, DType.INT32]:
pad_const_int = testGen.getRandNumberDType(dtype)
pad_const_fp = 0
- elif dtype in (DType.FP16, DType.FLOAT):
+ elif dtype in (DType.FP16, DType.FP32):
pad_const_int = 0
pad_const_fp = testGen.getRandNumberDType(dtype)
else:
@@ -1303,9 +1303,9 @@ class TosaArgGen:
elif dtype == DType.INT8 or dtype == DType.INT16:
accum_dtypes = [DType.INT32]
elif dtype == DType.FP16:
- accum_dtypes = [DType.FP16, DType.FLOAT]
- elif dtype == DType.FLOAT:
- accum_dtypes = [DType.FLOAT]
+ accum_dtypes = [DType.FP16, DType.FP32]
+ elif dtype == DType.FP32:
+ accum_dtypes = [DType.FP32]
elif error_name is None:
assert False, f"Invalid I/O DType for pooling: {DTypeNames[dtype]}"
else:
@@ -1408,20 +1408,20 @@ class TosaArgGen:
if error_name == ErrorIf.WrongOutputType:
dtypeList = TosaErrorIfArgGen.eiCastErrorIf(testGen, inDtype)
elif inDtype == DType.INT8:
- dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FLOAT]
+ dtypeList = [DType.BOOL, DType.INT16, DType.INT32, DType.FP32]
elif inDtype == DType.INT16:
- dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FLOAT]
+ dtypeList = [DType.BOOL, DType.INT8, DType.INT32, DType.FP32]
elif inDtype == DType.INT32:
- dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
+ dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
elif inDtype == DType.BOOL:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif inDtype == DType.FP16:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
- elif inDtype == DType.FLOAT:
+ elif inDtype == DType.FP32:
dtypeList = [DType.INT8, DType.INT16, DType.INT32]
elif error_name == ErrorIf.WrongInputType:
# Pick some potentially correct output type for incorrect input type
- dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FLOAT]
+ dtypeList = [DType.BOOL, DType.INT8, DType.INT16, DType.FP32]
else:
raise Exception("Unexpected input dtype: {}".format(inDtype))
@@ -1826,8 +1826,8 @@ class TosaArgGen:
outputDTypeList = [DType.INT48]
elif dtype == DType.FP16:
outputDTypeList = [DType.FP16]
- elif dtype == DType.FLOAT:
- outputDTypeList = [DType.FLOAT]
+ elif dtype == DType.FP32:
+ outputDTypeList = [DType.FP32]
elif error_name == ErrorIf.WrongInputType:
# If an incorrect input type is used then we set a 'correct'
# output type to avoid other errors
diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py
index a766803..abe1a97 100644
--- a/verif/generator/tosa_error_if.py
+++ b/verif/generator/tosa_error_if.py
@@ -119,7 +119,7 @@ class TosaErrorIfArgGen:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
)
elif mode == ResizeMode.NEAREST and dtype == DType.INT16:
@@ -128,7 +128,7 @@ class TosaErrorIfArgGen:
DType.INT8,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
)
elif mode == ResizeMode.BILINEAR and dtype == DType.INT8:
@@ -137,7 +137,7 @@ class TosaErrorIfArgGen:
DType.INT8,
DType.INT16,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
)
elif mode == ResizeMode.BILINEAR and dtype == DType.INT16:
@@ -146,7 +146,7 @@ class TosaErrorIfArgGen:
DType.INT8,
DType.INT16,
DType.INT32,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
)
elif dtype == DType.FP16:
@@ -156,9 +156,9 @@ class TosaErrorIfArgGen:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
)
- elif dtype == DType.FLOAT:
+ elif dtype == DType.FP32:
incorrect_types = (
DType.INT4,
DType.INT8,
@@ -299,8 +299,8 @@ class TosaErrorIfArgGen:
@staticmethod
def eiCastErrorIf(testGen, input_dtype):
- if input_dtype in [DType.BOOL, DType.FP16, DType.FLOAT]:
- outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FLOAT]
+ if input_dtype in [DType.BOOL, DType.FP16, DType.FP32]:
+ outputDType = [DType.BOOL, DType.INT48, DType.FP16, DType.FP32]
elif input_dtype in [DType.INT8, DType.INT16, DType.INT32]:
outputDType = [DType.INT48]
else:
@@ -366,6 +366,16 @@ class TosaErrorValidator:
}
wrong_input_dtypes = list(usableDTypes(excludes=allowed_input_dtypes))
+ # Turn the wrong dtypes into required list of types
+ if op["op"] in [
+ Op.FULLY_CONNECTED,
+ Op.CONV2D,
+ Op.CONV3D,
+ Op.DEPTHWISE_CONV2D,
+ Op.TRANSPOSE_CONV2D,
+ ]:
+ wrong_input_dtypes = [[t, t, t] for t in wrong_input_dtypes]
+
if op["op"] == Op.CLAMP:
wrong_input_dtypes.remove(DType.INT48)
@@ -415,7 +425,7 @@ class TosaErrorValidator:
and output_dtype != DType.INT48
)
or (input_dtype == DType.FP16 and output_dtype != DType.FP16)
- or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
+ or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
):
error_result = True
@@ -430,28 +440,28 @@ class TosaErrorValidator:
or (input_dtype == DType.INT16 and output_dtype != DType.INT48)
or (
input_dtype == DType.FP16
- and output_dtype not in (DType.FP16, DType.FLOAT)
+ and output_dtype not in (DType.FP16, DType.FP32)
)
- or (input_dtype == DType.FLOAT and output_dtype != DType.FLOAT)
+ or (input_dtype == DType.FP32 and output_dtype != DType.FP32)
):
error_result = True
elif op["op"] == Op.ARGMAX:
if (
- input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT]
+ input_dtype in [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
and output_dtype != DType.INT32
):
error_result = True
elif op["op"] == Op.MUL:
if (
- input_dtype not in (DType.FP16, DType.FLOAT)
+ input_dtype not in (DType.FP16, DType.FP32)
and output_dtype != DType.INT32
):
error_result = True
elif input_dtype == DType.FP16 and output_dtype != DType.FP16:
error_result = True
- elif input_dtype == DType.FLOAT and output_dtype != DType.FLOAT:
+ elif input_dtype == DType.FP32 and output_dtype != DType.FP32:
error_result = True
elif op["op"] == Op.TABLE:
@@ -477,7 +487,7 @@ class TosaErrorValidator:
DType.BOOL,
DType.INT16,
DType.INT32,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
]
)
@@ -488,7 +498,7 @@ class TosaErrorValidator:
DType.BOOL,
DType.INT8,
DType.INT32,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
]
)
@@ -499,7 +509,7 @@ class TosaErrorValidator:
DType.BOOL,
DType.INT8,
DType.INT16,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
]
)
@@ -508,7 +518,7 @@ class TosaErrorValidator:
and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
)
or (
- input_dtype == DType.FLOAT
+ input_dtype == DType.FP32
and output_dtype not in [DType.INT8, DType.INT16, DType.INT32]
)
):
@@ -526,9 +536,9 @@ class TosaErrorValidator:
or input_dtype == DType.INT16
and output_dtype != DType.INT48
or input_dtype == DType.FP16
- and output_dtype not in (DType.FP16, DType.FLOAT)
- or input_dtype == DType.FLOAT
- and output_dtype != DType.FLOAT
+ and output_dtype not in (DType.FP16, DType.FP32)
+ or input_dtype == DType.FP32
+ and output_dtype != DType.FP32
):
error_result = True
# invalid input types are ignored, to avoid reporting multiple errors
@@ -2306,12 +2316,12 @@ class TosaInvalidValidator:
not (input_dtype == DType.INT8 and output_dtype == DType.INT32)
and not (input_dtype == DType.INT16 and output_dtype == DType.INT48)
and not (input_dtype == DType.FP16 and output_dtype == DType.FP16)
- and not (input_dtype == DType.FLOAT and output_dtype == DType.FLOAT)
+ and not (input_dtype == DType.FP32 and output_dtype == DType.FP32)
)
elif mode == ResizeMode.NEAREST:
# Invalid output data type / Invalid input datatype
return (input_dtype != output_dtype) or (
- input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT]
+ input_dtype not in [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
)
else:
# Invalid resize mode
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 9ff6ec5..78d86cd 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -13,6 +13,7 @@ from generator.tosa_error_if import ErrorIf
from generator.tosa_error_if import TosaErrorIfArgGen
from generator.tosa_error_if import TosaErrorValidator
from generator.tosa_error_if import TosaInvalidValidator
+from generator.tosa_utils import DTYPE_ATTRIBUTES
from generator.tosa_utils import MAX_RESIZE_DIMENSION
from generator.tosa_utils import usableDTypes
from tosa.DType import DType
@@ -83,7 +84,7 @@ class TosaTestGen:
)
elif dtype == DType.FP16:
return np.float16(self.rng.random(size=shape))
- elif dtype == DType.FLOAT:
+ elif dtype == DType.FP32:
return np.float32(self.rng.random(size=shape))
else:
raise Exception("Unrecognized Dtype: {}".format(dtype))
@@ -128,7 +129,7 @@ class TosaTestGen:
return np.int32(self.rng.integers(low=low, high=high, size=1))[0]
def getRandNumberDType(self, dtype):
- if dtype == DType.FLOAT:
+ if dtype == DType.FP32:
return self.rng.random()
elif dtype == DType.FP16:
rand_f32 = self.rng.random()
@@ -162,58 +163,26 @@ class TosaTestGen:
return "x".join(sStr)
- def typeStr(self, t):
- if isinstance(t, list):
- assert len(t) >= 2
- return "{}x{}".format(self.typeStr(t[0]), self.typeStr(t[1]))
+ def typeStr(self, dtype):
+ if isinstance(dtype, list) or isinstance(dtype, tuple):
+ assert len(dtype) >= 2
+ strs = [self.typeStr(t) for t in dtype]
+ # Limit types to the first 2 as the 3rd is the accumulator
+ return "x".join(strs[:2])
else:
- if t == DType.BOOL:
- return "b"
- elif t == DType.INT4:
- return "i4"
- elif t == DType.INT8:
- return "i8"
- elif t == DType.UINT8:
- return "u8"
- elif t == DType.INT16:
- return "i16"
- elif t == DType.UINT16:
- return "u16"
- elif t == DType.INT32:
- return "i32"
- elif t == DType.INT48:
- return "i48"
- elif t == DType.FP16:
- return "f16"
- elif t == DType.FLOAT:
- return "float"
+ if dtype in DTYPE_ATTRIBUTES:
+ return DTYPE_ATTRIBUTES[dtype]["str"]
else:
- raise Exception("Unknown dtype, cannot convert to string: {}".format(t))
+ raise Exception(
+ "Unknown dtype, cannot convert to string: {}".format(dtype)
+ )
- def typeWidth(self, t):
+ def typeWidth(self, dtype):
"""Get the datatype width for data types"""
- if t == DType.INT4:
- return 4
- elif t == DType.INT8:
- return 8
- elif t == DType.UINT8:
- return 8
- elif t == DType.INT16:
- return 16
- elif t == DType.UINT16:
- return 16
- elif t == DType.INT32:
- return 32
- elif t == DType.INT48:
- return 48
- elif t == DType.FP16:
- return 16
- elif t == DType.FLOAT:
- return 32
- elif t == DType.BOOL:
- return 1
+ if dtype in DTYPE_ATTRIBUTES:
+ return DTYPE_ATTRIBUTES[dtype]["width"]
else:
- raise Exception(f"Unknown dtype, cannot determine width: {t}")
+ raise Exception(f"Unknown dtype, cannot determine width: {dtype}")
# Argument generators
# Returns a list of tuples (stringDescriptor, [build_fcn_arg_list])
@@ -355,7 +324,7 @@ class TosaTestGen:
# Special for multiply:
# Force the result to INT32 for INT types
- if a.dtype not in (DType.FP16, DType.FLOAT):
+ if a.dtype not in (DType.FP16, DType.FP32):
result_tens.setDtype(DType.INT32)
if error_name == ErrorIf.WrongOutputType:
all_dtypes = [DType.INT8, DType.INT16, DType.INT48]
@@ -1074,7 +1043,7 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- if a.dtype in (DType.FP16, DType.FLOAT):
+ if a.dtype in (DType.FP16, DType.FP32):
attr.ClampAttribute(0, 0, min_val, max_val)
else:
attr.ClampAttribute(min_val, max_val, 0, 0)
@@ -1086,7 +1055,7 @@ class TosaTestGen:
result_tens = OutputShaper.unaryOp(self.ser, self.rng, a, error_name)
attr = ts.TosaSerializerAttribute()
- attr.LeakyReluAttribute(self.getRandNumberDType(DType.FLOAT))
+ attr.LeakyReluAttribute(self.getRandNumberDType(DType.FP32))
self.ser.addOperator(op["op"], [a.name], [result_tens.name], attr)
return result_tens
@@ -1890,7 +1859,7 @@ class TosaTestGen:
op["op"], [cond_tens.name, a.name, b.name], [result_tens.name], attr
)
- if a.dtype in (DType.FLOAT, DType.FP16, DType.INT32):
+ if a.dtype in (DType.FP32, DType.FP16, DType.INT32):
then_op, else_op = Op.ADD, Op.SUB
elif a.dtype in (DType.INT8, DType.INT16):
then_op, else_op = Op.LOGICAL_RIGHT_SHIFT, Op.LOGICAL_LEFT_SHIFT
@@ -2001,7 +1970,7 @@ class TosaTestGen:
if error_name == ErrorIf.CondGraphOutputNotMatchingBool:
cond_tens = self.ser.addOutput(
- [], self.rng.choice([DType.INT8, DType.INT32, DType.FLOAT])
+ [], self.rng.choice([DType.INT8, DType.INT32, DType.FP32])
)
else:
cond_tens = self.ser.addOutput([], DType.BOOL)
@@ -2429,7 +2398,7 @@ class TosaTestGen:
# if not specified, defaults to (1, 4)
# 'build_fcn': tuple of the function to (build_operator(), TensorGen function, ArgGen enum)
# 'types': array of datatypes to be tested
- TYPE_FP = [DType.FLOAT, DType.FP16]
+ TYPE_FP = [DType.FP32, DType.FP16]
TYPE_INT = [DType.INT8, DType.INT16, DType.INT32] # Excludes INT4
TYPE_INT_FP = [
@@ -2437,30 +2406,31 @@ class TosaTestGen:
DType.INT16,
DType.INT32,
DType.FP16,
- DType.FLOAT,
+ DType.FP32,
] # Excludes INT4
TYPE_BOOL = [DType.BOOL]
- TYPE_FI32 = [DType.FLOAT, DType.FP16, DType.INT32] # floating-types and INT32
+ TYPE_FI32 = [DType.FP32, DType.FP16, DType.INT32] # floating-types and INT32
TYPE_FIB = [
DType.FP16,
- DType.FLOAT,
+ DType.FP32,
DType.INT8,
DType.INT16,
DType.INT32,
DType.BOOL,
]
- TYPE_FI16 = [DType.FLOAT, DType.INT16]
+ TYPE_FI16 = [DType.FP32, DType.INT16]
- TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FLOAT]
+ TYPE_NARROW_INT_FP = [DType.INT8, DType.INT16, DType.FP16, DType.FP32]
+ # List of [Input Type 1, Input Type 2, Accumulator Type]
TYPE_CONV = [
[DType.INT8, DType.INT4, DType.INT32],
[DType.INT8, DType.INT8, DType.INT32],
[DType.INT16, DType.INT8, DType.INT48],
[DType.FP16, DType.FP16, DType.FP16],
- [DType.FP16, DType.FP16, DType.FLOAT],
- DType.FLOAT,
+ [DType.FP16, DType.FP16, DType.FP32],
+ [DType.FP32, DType.FP32, DType.FP32],
]
DEFAULT_RANK_RANGE = (1, TOSA_TENSOR_MAX_RANK)
@@ -3478,7 +3448,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgReduceSum,
TosaArgGen.agAxis,
),
- "types": (DType.FP16, DType.FLOAT, DType.INT32),
+ "types": (DType.FP16, DType.FP32, DType.INT32),
"error_if_validators": (
TosaErrorValidator.evAxisLargerRank,
TosaErrorValidator.evAxisSmallerZero,
@@ -3665,7 +3635,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgDefault,
None,
),
- "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FLOAT),
+ "types": (DType.INT8, DType.INT16, DType.INT32, DType.FP16, DType.FP32),
"error_if_validators": (
TosaErrorValidator.evWrongInputType,
TosaErrorValidator.evWrongOutputType,
@@ -3706,7 +3676,7 @@ class TosaTestGen:
TosaTensorValuesGen.tvgDefault,
TosaArgGen.agResize,
),
- "types": (DType.INT8, DType.INT16, DType.FP16, DType.FLOAT),
+ "types": (DType.INT8, DType.INT16, DType.FP16, DType.FP32),
"invalid_test_validators": (
TosaInvalidValidator.ivWrongDataTypeOrModeResize,
),
@@ -3742,7 +3712,7 @@ class TosaTestGen:
),
"types": (
DType.FP16,
- DType.FLOAT,
+ DType.FP32,
DType.INT8,
DType.INT16,
DType.INT32,
@@ -3872,7 +3842,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -3901,7 +3871,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -3929,7 +3899,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -3958,7 +3928,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
outputDType = rng.choice(wrong_dtypes)
else:
@@ -3984,7 +3954,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4016,7 +3986,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([DType.INT32]))
outputDType = rng.choice(wrong_dtypes)
@@ -4069,7 +4039,7 @@ class OutputShaper:
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
- excludes = [DType.FP16, DType.FLOAT]
+ excludes = [DType.FP16, DType.FP32]
else:
excludes = [out_dtype]
wrong_dtypes = list(usableDTypes(excludes=excludes))
@@ -4131,7 +4101,7 @@ class OutputShaper:
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
- excludes = [DType.FP16, DType.FLOAT]
+ excludes = [DType.FP16, DType.FP32]
else:
excludes = [out_dtype]
wrong_dtypes = list(usableDTypes(excludes=excludes))
@@ -4182,7 +4152,7 @@ class OutputShaper:
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
- excludes = [DType.FP16, DType.FLOAT]
+ excludes = [DType.FP16, DType.FP32]
else:
excludes = [out_dtype]
wrong_dtypes = list(usableDTypes(excludes=excludes))
@@ -4217,7 +4187,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
]
wrong_dtypes = list(set(all_dtypes) - set([ifm.dtype]))
@@ -4255,7 +4225,7 @@ class OutputShaper:
DType.INT8,
DType.INT16,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
)
elif a.dtype == DType.INT16:
incorrect_types = (
@@ -4263,9 +4233,9 @@ class OutputShaper:
DType.INT8,
DType.INT16,
DType.INT32,
- DType.FLOAT,
+ DType.FP32,
)
- elif a.dtype == DType.FLOAT or a.dtype == DType.FP16:
+ elif a.dtype == DType.FP32 or a.dtype == DType.FP16:
incorrect_types = (
DType.INT4,
DType.INT8,
@@ -4307,7 +4277,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
}
wrong_dtypes = list(all_dtypes - set([input1.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4334,7 +4304,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
@@ -4358,7 +4328,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4376,7 +4346,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4412,7 +4382,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4440,7 +4410,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([a.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4464,7 +4434,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([values.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4491,7 +4461,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes = list(set(all_dtypes) - set([values_in.dtype]))
outputDType = rng.choice(wrong_dtypes)
@@ -4512,7 +4482,7 @@ class OutputShaper:
DType.INT16,
DType.INT32,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
]
wrong_dtypes.remove(output_dtype)
output_dtype = rng.choice(wrong_dtypes)
@@ -4619,7 +4589,7 @@ class OutputShaper:
if error_name == ErrorIf.WrongOutputType:
if ifm.dtype == DType.FP16:
- excludes = [DType.FP16, DType.FLOAT]
+ excludes = [DType.FP16, DType.FP32]
else:
excludes = [out_dtype]
wrong_dtypes = list(usableDTypes(excludes=excludes))
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index 7fa31e7..104d9bb 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -5,6 +5,19 @@ from tosa.DType import DType
# Maximum dimension size for output and inputs for RESIZE
MAX_RESIZE_DIMENSION = 16384
+DTYPE_ATTRIBUTES = {
+ DType.BOOL: {"str": "b", "width": 1},
+ DType.INT4: {"str": "i4", "width": 4},
+ DType.INT8: {"str": "i8", "width": 8},
+ DType.UINT8: {"str": "u8", "width": 8},
+ DType.INT16: {"str": "i16", "width": 16},
+ DType.UINT16: {"str": "u16", "width": 16},
+ DType.INT32: {"str": "i32", "width": 32},
+ DType.INT48: {"str": "i48", "width": 48},
+ DType.FP16: {"str": "f16", "width": 16},
+ DType.FP32: {"str": "f32", "width": 32},
+}
+
def valueToName(item, value):
"""Get the name of an attribute with the given value.
@@ -88,10 +101,8 @@ def product(shape):
def get_accum_dtype_from_tgTypes(dtypes):
# Get accumulate data-type from the test generator's defined types
- if isinstance(dtypes, list) or isinstance(dtypes, tuple):
- return dtypes[-1]
- else:
- return dtypes
+ assert isinstance(dtypes, list) or isinstance(dtypes, tuple)
+ return dtypes[-1]
def get_wrong_output_type(op_name, rng, input_dtype):
@@ -102,7 +113,7 @@ def get_wrong_output_type(op_name, rng, input_dtype):
DType.INT8,
DType.INT16,
DType.INT48,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
)
elif input_dtype == DType.INT16:
@@ -111,10 +122,10 @@ def get_wrong_output_type(op_name, rng, input_dtype):
DType.INT8,
DType.INT16,
DType.INT32,
- DType.FLOAT,
+ DType.FP32,
DType.FP16,
)
- elif input_dtype == DType.FLOAT or input_dtype == DType.FP16:
+ elif input_dtype == DType.FP32 or input_dtype == DType.FP16:
incorrect_types = (
DType.INT4,
DType.INT8,
diff --git a/verif/tests/test_tosa_refmodel.py b/verif/tests/test_tosa_refmodel.py
index fbe3a7d..b608fd8 100644
--- a/verif/tests/test_tosa_refmodel.py
+++ b/verif/tests/test_tosa_refmodel.py
@@ -45,7 +45,7 @@ REF_MODEL_TYPE_TO_OUT = {
"uint8": "u8",
"int16": "i16",
"int32": "i32",
- "float": "float",
+ "fp32": "f32",
"fp16": "f16",
}
@@ -123,21 +123,21 @@ class BuildTosaTest:
# Tests - op_name, ref_model_type, num_expected_tests
TEST_PARAMS = [
("add", "int32", 1),
- ("add", "float", 1),
+ ("add", "fp32", 1),
("abs", "int32", 1),
- ("abs", "float", 1),
+ ("abs", "fp32", 1),
("abs", "fp16", 1),
("negate", "int8", 1),
("negate", "int16", 1),
("negate", "int32", 1),
- ("negate", "float", 1),
+ ("negate", "fp32", 1),
("negate", "fp16", 1),
# One test per axis (shape dimensions)
("concat", "bool", SHAPE_DIMS),
("concat", "int8", SHAPE_DIMS),
("concat", "int16", SHAPE_DIMS),
("concat", "int32", SHAPE_DIMS),
- ("concat", "float", SHAPE_DIMS),
+ ("concat", "fp32", SHAPE_DIMS),
("concat", "fp16", SHAPE_DIMS),
]