aboutsummaryrefslogtreecommitdiff
path: root/reference_model/src/ops/data_layout.cc
diff options
context:
space:
mode:
authorTai Ly <tai.ly@arm.com>2023-03-28 22:06:56 +0000
committerTai Ly <tai.ly@arm.com>2023-05-05 19:23:15 +0000
commita4d748b08accce06fab93e2d2b96e499b35ae89b (patch)
tree20a3957e1f45f65f35d5d67ecce1618659e388f0 /reference_model/src/ops/data_layout.cc
parent0c71686875618b2e11290273b7a05b88ef8a8aae (diff)
downloadreference_model-a4d748b08accce06fab93e2d2b96e499b35ae89b.tar.gz
[reference model] Add precise mode
This adds --precise_mode=1 option to tosa_referece_model, which will cause reference model to convert all floating point tensors to FP64 tensors and compute all operators accordingly. Also adds optional -p arguments to test runners tosa_verif_run_tests.py and tosa_verif_framework_compiler_runner.py to run tests in precise mode Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I156055216ad61710096497a8fa1a653be2a602a3
Diffstat (limited to 'reference_model/src/ops/data_layout.cc')
-rw-r--r--reference_model/src/ops/data_layout.cc93
1 files changed, 51 insertions, 42 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
index a189466..442cef8 100644
--- a/reference_model/src/ops/data_layout.cc
+++ b/reference_model/src/ops/data_layout.cc
@@ -20,7 +20,7 @@ using namespace TosaReference;
using namespace Eigen;
using namespace tosa;
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -32,14 +32,14 @@ OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Axis);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpConcat<Rank, Dtype>::~OpConcat()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpConcat<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -100,7 +100,7 @@ int OpConcat<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpConcat<Rank, Dtype>::eval()
{
@@ -124,7 +124,7 @@ int OpConcat<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -136,12 +136,12 @@ OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Pad);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpPad<Rank, Dtype>::~OpPad()
{
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpPad<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -185,22 +185,23 @@ int OpPad<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpPad<Rank, Dtype>::eval()
{
InEigenType pad_value = 0;
switch (Dtype)
{
- case DType_BOOL:
- case DType_INT8:
- case DType_INT16:
- case DType_INT32:
+ case TOSA_REF_TYPE_BOOL:
+ case TOSA_REF_TYPE_INT8:
+ case TOSA_REF_TYPE_INT16:
+ case TOSA_REF_TYPE_INT32:
pad_value = (InEigenType)attribute->pad_const_int();
break;
- case DType_FP16:
- case DType_BF16:
- case DType_FP32:
+ case TOSA_REF_TYPE_FP16:
+ case TOSA_REF_TYPE_BF16:
+ case TOSA_REF_TYPE_FP32:
+ case TOSA_REF_TYPE_FP64:
pad_value = (InEigenType)attribute->pad_const_fp();
break;
default:
@@ -213,7 +214,7 @@ int OpPad<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int InRank, int OutRank, DType Dtype>
+template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -225,14 +226,14 @@ OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Reshape);
}
-template <int InRank, int OutRank, DType Dtype>
+template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
OpReshape<InRank, OutRank, Dtype>::~OpReshape()
{
if (attribute)
delete attribute;
}
-template <int InRank, int OutRank, DType Dtype>
+template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -270,7 +271,7 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int InRank, int OutRank, DType Dtype>
+template <int InRank, int OutRank, TOSA_REF_TYPE Dtype>
int OpReshape<InRank, OutRank, Dtype>::eval()
{
for (int32_t d = 0; d < OutRank; d++)
@@ -313,7 +314,7 @@ int OpReshape<InRank, OutRank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -325,14 +326,14 @@ OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Axis);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpReverse<Rank, Dtype>::~OpReverse()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReverse<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -376,7 +377,7 @@ int OpReverse<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpReverse<Rank, Dtype>::eval()
{
out->getTensor() = in->getTensor().reverse(reverse_array);
@@ -384,7 +385,7 @@ int OpReverse<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -396,14 +397,14 @@ OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Slice);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpSlice<Rank, Dtype>::~OpSlice()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpSlice<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -449,7 +450,7 @@ int OpSlice<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpSlice<Rank, Dtype>::eval()
{
out->getTensor() = in->getTensor().slice(begin_array, size_array);
@@ -457,7 +458,7 @@ int OpSlice<Rank, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -469,14 +470,14 @@ OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Tile);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpTileBase<Rank, Dtype>::~OpTileBase()
{
if (attribute)
delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpTileBase<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -518,14 +519,14 @@ int OpTileBase<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpTile<Rank, Dtype>::eval()
{
// primary template shouldn't be called
- FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]);
+ FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNameTOSAREFTYPE(Dtype));
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpTile<1, Dtype>::eval()
{
for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -537,7 +538,7 @@ int OpTile<1, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpTile<2, Dtype>::eval()
{
for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -553,7 +554,7 @@ int OpTile<2, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpTile<3, Dtype>::eval()
{
for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -573,7 +574,7 @@ int OpTile<3, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpTile<4, Dtype>::eval()
{
for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -597,7 +598,7 @@ int OpTile<4, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpTile<5, Dtype>::eval()
{
for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -626,7 +627,7 @@ int OpTile<5, Dtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
int OpTile<6, Dtype>::eval()
{
for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
@@ -659,7 +660,7 @@ int OpTile<6, Dtype>::eval()
return GraphNode::eval();
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
@@ -671,13 +672,13 @@ OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Transpose);
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
OpTranspose<Rank, Dtype>::~OpTranspose()
{
if (attribute) delete attribute;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpTranspose<Rank, Dtype>::checkTensorAttributes()
{
// Check Tosa Level
@@ -727,7 +728,7 @@ int OpTranspose<Rank, Dtype>::checkTensorAttributes()
return 0;
}
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
int OpTranspose<Rank, Dtype>::eval()
{
out->getTensor() = in->getTensor().shuffle(perm_array);
@@ -743,6 +744,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64)
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16);
@@ -751,6 +753,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64);
DEF_INSTANTIATE_RESHAPE(OpReshape, FP16);
DEF_INSTANTIATE_RESHAPE(OpReshape, BF16);
@@ -759,6 +762,7 @@ DEF_INSTANTIATE_RESHAPE(OpReshape, INT8);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
DEF_INSTANTIATE_RESHAPE(OpReshape, INT32);
DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);
+DEF_INSTANTIATE_RESHAPE(OpReshape, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16);
@@ -767,6 +771,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16);
@@ -775,6 +780,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16);
@@ -783,6 +789,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16);
@@ -791,6 +798,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16);
@@ -799,3 +807,4 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64);