diff options
Diffstat (limited to 'reference_model/src/ops/data_layout.cc')
-rw-r--r-- | reference_model/src/ops/data_layout.cc | 93 |
1 files changed, 51 insertions, 42 deletions
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc index a189466..442cef8 100644 --- a/reference_model/src/ops/data_layout.cc +++ b/reference_model/src/ops/data_layout.cc @@ -20,7 +20,7 @@ using namespace TosaReference; using namespace Eigen; using namespace tosa; -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -32,14 +32,14 @@ OpConcat<Rank, Dtype>::OpConcat(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Axis); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpConcat<Rank, Dtype>::~OpConcat() { if (attribute) delete attribute; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpConcat<Rank, Dtype>::checkTensorAttributes() { // Check Tosa Level @@ -100,7 +100,7 @@ int OpConcat<Rank, Dtype>::checkTensorAttributes() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpConcat<Rank, Dtype>::eval() { @@ -124,7 +124,7 @@ int OpConcat<Rank, Dtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -136,12 +136,12 @@ OpPad<Rank, Dtype>::OpPad(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Pad); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpPad<Rank, Dtype>::~OpPad() { } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpPad<Rank, Dtype>::checkTensorAttributes() { // Check Tosa Level @@ -185,22 +185,23 @@ int OpPad<Rank, Dtype>::checkTensorAttributes() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpPad<Rank, Dtype>::eval() { InEigenType pad_value = 0; switch (Dtype) { - case DType_BOOL: - case DType_INT8: - case DType_INT16: - case DType_INT32: + case TOSA_REF_TYPE_BOOL: + case TOSA_REF_TYPE_INT8: + case TOSA_REF_TYPE_INT16: + case TOSA_REF_TYPE_INT32: pad_value = (InEigenType)attribute->pad_const_int(); break; - case DType_FP16: - case DType_BF16: - case DType_FP32: + case TOSA_REF_TYPE_FP16: + case TOSA_REF_TYPE_BF16: + case TOSA_REF_TYPE_FP32: + case TOSA_REF_TYPE_FP64: pad_value = (InEigenType)attribute->pad_const_fp(); break; default: @@ -213,7 +214,7 @@ int OpPad<Rank, Dtype>::eval() return GraphNode::eval(); } -template <int InRank, int OutRank, DType Dtype> +template <int InRank, int OutRank, TOSA_REF_TYPE Dtype> OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -225,14 +226,14 @@ OpReshape<InRank, OutRank, Dtype>::OpReshape(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Reshape); } -template <int InRank, int OutRank, DType Dtype> +template <int InRank, int OutRank, TOSA_REF_TYPE Dtype> OpReshape<InRank, OutRank, Dtype>::~OpReshape() { if (attribute) delete attribute; } -template <int InRank, int OutRank, DType Dtype> +template <int InRank, int OutRank, TOSA_REF_TYPE Dtype> int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes() { // Check Tosa Level @@ -270,7 +271,7 @@ int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes() return 0; } -template <int InRank, int OutRank, DType Dtype> +template <int InRank, int OutRank, TOSA_REF_TYPE Dtype> int OpReshape<InRank, OutRank, Dtype>::eval() { for (int32_t d = 0; d < OutRank; d++) @@ -313,7 +314,7 @@ int OpReshape<InRank, OutRank, Dtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -325,14 +326,14 @@ OpReverse<Rank, Dtype>::OpReverse(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Axis); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpReverse<Rank, Dtype>::~OpReverse() { if (attribute) delete attribute; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpReverse<Rank, Dtype>::checkTensorAttributes() { // Check Tosa Level @@ -376,7 +377,7 @@ int OpReverse<Rank, Dtype>::checkTensorAttributes() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpReverse<Rank, Dtype>::eval() { out->getTensor() = in->getTensor().reverse(reverse_array); @@ -384,7 +385,7 @@ int OpReverse<Rank, Dtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -396,14 +397,14 @@ OpSlice<Rank, Dtype>::OpSlice(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Slice); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpSlice<Rank, Dtype>::~OpSlice() { if (attribute) delete attribute; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpSlice<Rank, Dtype>::checkTensorAttributes() { // Check Tosa Level @@ -449,7 +450,7 @@ int OpSlice<Rank, Dtype>::checkTensorAttributes() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpSlice<Rank, Dtype>::eval() { out->getTensor() = in->getTensor().slice(begin_array, size_array); @@ -457,7 +458,7 @@ int OpSlice<Rank, Dtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -469,14 +470,14 @@ OpTileBase<Rank, Dtype>::OpTileBase(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Tile); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpTileBase<Rank, Dtype>::~OpTileBase() { if (attribute) delete attribute; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpTileBase<Rank, Dtype>::checkTensorAttributes() { // Check Tosa Level @@ -518,14 +519,14 @@ int OpTileBase<Rank, Dtype>::checkTensorAttributes() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpTile<Rank, Dtype>::eval() { // primary template shouldn't be called - FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]); + FATAL_ERROR("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNameTOSAREFTYPE(Dtype)); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpTile<1, Dtype>::eval() { for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) @@ -537,7 +538,7 @@ int OpTile<1, Dtype>::eval() return GraphNode::eval(); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpTile<2, Dtype>::eval() { for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) @@ -553,7 +554,7 @@ int OpTile<2, Dtype>::eval() return GraphNode::eval(); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpTile<3, Dtype>::eval() { for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) @@ -573,7 +574,7 @@ int OpTile<3, Dtype>::eval() return GraphNode::eval(); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpTile<4, Dtype>::eval() { for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) @@ -597,7 +598,7 @@ int OpTile<4, Dtype>::eval() return GraphNode::eval(); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpTile<5, Dtype>::eval() { for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) @@ -626,7 +627,7 @@ int OpTile<5, Dtype>::eval() return GraphNode::eval(); } -template <DType Dtype> +template <TOSA_REF_TYPE Dtype> int OpTile<6, Dtype>::eval() { for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) @@ -659,7 +660,7 @@ int OpTile<6, Dtype>::eval() return GraphNode::eval(); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) @@ -671,13 +672,13 @@ OpTranspose<Rank, Dtype>::OpTranspose(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Transpose); } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> OpTranspose<Rank, Dtype>::~OpTranspose() { if (attribute) delete attribute; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpTranspose<Rank, Dtype>::checkTensorAttributes() { // Check Tosa Level @@ -727,7 +728,7 @@ int OpTranspose<Rank, Dtype>::checkTensorAttributes() return 0; } -template <int Rank, DType Dtype> +template <int Rank, TOSA_REF_TYPE Dtype> int OpTranspose<Rank, Dtype>::eval() { out->getTensor() = in->getTensor().shuffle(perm_array); @@ -743,6 +744,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL) +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FP64) DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BF16); @@ -751,6 +753,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FP64); DEF_INSTANTIATE_RESHAPE(OpReshape, FP16); DEF_INSTANTIATE_RESHAPE(OpReshape, BF16); @@ -759,6 +762,7 @@ DEF_INSTANTIATE_RESHAPE(OpReshape, INT8); DEF_INSTANTIATE_RESHAPE(OpReshape, INT16); DEF_INSTANTIATE_RESHAPE(OpReshape, INT32); DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL); +DEF_INSTANTIATE_RESHAPE(OpReshape, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BF16); @@ -767,6 +771,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BF16); @@ -775,6 +780,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpSlice, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BF16); @@ -783,6 +789,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, BOOL); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTileBase, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BF16); @@ -791,6 +798,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTile, FP64); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BF16); @@ -799,3 +807,4 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpTranspose, FP64); |