aboutsummaryrefslogtreecommitdiff
path: root/compute_kernel_writer/src/Tensor3dMapper.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'compute_kernel_writer/src/Tensor3dMapper.cpp')
-rw-r--r--compute_kernel_writer/src/Tensor3dMapper.cpp103
1 files changed, 59 insertions, 44 deletions
diff --git a/compute_kernel_writer/src/Tensor3dMapper.cpp b/compute_kernel_writer/src/Tensor3dMapper.cpp
index 60f4f57466..7384b924da 100644
--- a/compute_kernel_writer/src/Tensor3dMapper.cpp
+++ b/compute_kernel_writer/src/Tensor3dMapper.cpp
@@ -27,114 +27,129 @@
#include "ckw/Error.h"
#include "ckw/types/TensorSamplerTypes.h"
#include "src/ITensor.h"
-
+#include "src/ITile.h"
namespace ckw
{
-Tensor3dMapper::Tensor3dMapper(ITensor *tensor, TensorSampler sampler)
- : _tensor(tensor), _sampler(sampler)
+Tensor3dMapper::Tensor3dMapper(ITensor *tensor, TensorSamplerFormat format)
+ : _tensor(tensor), _format(format)
{
}
-std::string Tensor3dMapper::tensor_component_x() const
+TileVariable Tensor3dMapper::dim_x() const
{
- const TensorSamplerFormat format = _sampler.format();
- switch(format)
+ switch(_format)
{
case TensorSamplerFormat::Dim0_Dim1xDim2_1:
case TensorSamplerFormat::Dim0_Dim1_Dim2:
- return _tensor->component(TensorComponentType::Dim0).scalar(0,0).str;
+ return _tensor->component(TensorComponentType::Dim0).scalar(0, 0);
default:
CKW_THROW_MSG("Unsupported tensor format");
- return "";
+ return _tensor->component(TensorComponentType::Unknown).scalar(0, 0);
}
}
-std::string Tensor3dMapper::tensor_component_y() const
+TileVariable Tensor3dMapper::dim_y() const
{
- const TensorSamplerFormat format = _sampler.format();
- switch(format)
+ switch(_format)
{
case TensorSamplerFormat::Dim0_Dim1xDim2_1:
- return _tensor->component(TensorComponentType::Dim1xDim2).scalar(0,0).str;
+ return _tensor->component(TensorComponentType::Dim1xDim2).scalar(0, 0);
case TensorSamplerFormat::Dim0_Dim1_Dim2:
- return _tensor->component(TensorComponentType::Dim1).scalar(0,0).str;
+ return _tensor->component(TensorComponentType::Dim1).scalar(0, 0);
default:
CKW_THROW_MSG("Unsupported tensor format");
- return "";
+ return _tensor->component(TensorComponentType::Unknown).scalar(0, 0);
}
}
-std::string Tensor3dMapper::tensor_component_z() const
+TileVariable Tensor3dMapper::dim_z() const
{
- const TensorSamplerFormat format = _sampler.format();
- switch(format)
+ TileVariable dim_one;
+
+ switch(_format)
{
case TensorSamplerFormat::Dim0_Dim1xDim2_1:
- return "1";
+ dim_one = _tensor->component(TensorComponentType::Dim3).scalar(0, 0);
+ dim_one.str = "1";
+ return dim_one;
case TensorSamplerFormat::Dim0_Dim1_Dim2:
- return _tensor->component(TensorComponentType::Dim2).scalar(0,0).str;
+ return _tensor->component(TensorComponentType::Dim2).scalar(0, 0);
default:
CKW_THROW_MSG("Unsupported tensor format");
- return "";
+ return _tensor->component(TensorComponentType::Unknown).scalar(0, 0);
}
}
-std::string Tensor3dMapper::tensor_component_stride_x() const
+TileVariable Tensor3dMapper::dim_batch() const
{
- const TensorSamplerFormat format = _sampler.format();
- switch(format)
+ TileVariable dim_one;
+
+ switch(_format)
{
case TensorSamplerFormat::Dim0_Dim1xDim2_1:
case TensorSamplerFormat::Dim0_Dim1_Dim2:
- return _tensor->component(TensorComponentType::Stride0).scalar(0,0).str;
+ return _tensor->component(TensorComponentType::Dim3).scalar(0, 0);
default:
CKW_THROW_MSG("Unsupported tensor format");
- return "";
+ return _tensor->component(TensorComponentType::Unknown).scalar(0, 0);
}
}
-std::string Tensor3dMapper::tensor_component_stride_y() const
+TileVariable Tensor3dMapper::stride_x() const
{
- const TensorSamplerFormat format = _sampler.format();
- switch(format)
+ switch(_format)
{
case TensorSamplerFormat::Dim0_Dim1xDim2_1:
case TensorSamplerFormat::Dim0_Dim1_Dim2:
- return _tensor->component(TensorComponentType::Stride1).scalar(0,0).str;
+ return _tensor->component(TensorComponentType::Stride0).scalar(0, 0);
default:
CKW_THROW_MSG("Unsupported tensor format");
- return "";
+ return _tensor->component(TensorComponentType::Unknown).scalar(0, 0);
}
}
-std::string Tensor3dMapper::tensor_component_stride_z() const
+TileVariable Tensor3dMapper::stride_y() const
{
- const TensorSamplerFormat format = _sampler.format();
- switch(format)
+ switch(_format)
{
case TensorSamplerFormat::Dim0_Dim1xDim2_1:
- return "0";
case TensorSamplerFormat::Dim0_Dim1_Dim2:
- return _tensor->component(TensorComponentType::Stride2).scalar(0,0).str;
+ return _tensor->component(TensorComponentType::Stride1).scalar(0, 0);
default:
CKW_THROW_MSG("Unsupported tensor format");
- return "";
+ return _tensor->component(TensorComponentType::Unknown).scalar(0, 0);
}
}
-std::string Tensor3dMapper::tensor_component_stride_batch() const
+TileVariable Tensor3dMapper::stride_z() const
{
- return _tensor->component(TensorComponentType::Stride3).scalar(0,0).str;
-}
+ TileVariable stride_zero;
-TensorSampler Tensor3dMapper::sampler() const
-{
- return _sampler;
+ switch(_format)
+ {
+ case TensorSamplerFormat::Dim0_Dim1xDim2_1:
+ stride_zero = _tensor->component(TensorComponentType::Stride3).scalar(0, 0);
+ stride_zero.str = "0";
+ return stride_zero;
+ case TensorSamplerFormat::Dim0_Dim1_Dim2:
+ return _tensor->component(TensorComponentType::Stride2).scalar(0, 0);
+ default:
+ CKW_THROW_MSG("Unsupported tensor format");
+ return _tensor->component(TensorComponentType::Unknown).scalar(0, 0);
+ }
}
-ITensor *Tensor3dMapper::tensor() const
+TileVariable Tensor3dMapper::stride_batch() const
{
- return _tensor;
+ switch(_format)
+ {
+ case TensorSamplerFormat::Dim0_Dim1xDim2_1:
+ case TensorSamplerFormat::Dim0_Dim1_Dim2:
+ return _tensor->component(TensorComponentType::Stride3).scalar(0, 0);
+ default:
+ CKW_THROW_MSG("Unsupported tensor format");
+ return _tensor->component(TensorComponentType::Unknown).scalar(0, 0);
+ }
}
} // namespace ckw \ No newline at end of file