aboutsummaryrefslogtreecommitdiff
path: root/src/dynamic_fusion
diff options
context:
space:
mode:
Diffstat (limited to 'src/dynamic_fusion')
-rw-r--r--src/dynamic_fusion/sketch/gpu/operators/GpuPool2d.cpp11
-rw-r--r--src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplatePool2d.cpp13
2 files changed, 18 insertions, 6 deletions
diff --git a/src/dynamic_fusion/sketch/gpu/operators/GpuPool2d.cpp b/src/dynamic_fusion/sketch/gpu/operators/GpuPool2d.cpp
index a07ad00155..c602f45164 100644
--- a/src/dynamic_fusion/sketch/gpu/operators/GpuPool2d.cpp
+++ b/src/dynamic_fusion/sketch/gpu/operators/GpuPool2d.cpp
@@ -60,6 +60,17 @@ bool GpuPool2dSettings::mixed_precision() const
return _mixed_precision;
}
+GpuPool2dSettings GpuPool2dSettings::use_inf_as_limit(bool use_inf_as_limit)
+{
+ _use_inf_as_limit = use_inf_as_limit;
+ return *this;
+}
+
+bool GpuPool2dSettings::use_inf_as_limit() const
+{
+ return _use_inf_as_limit;
+}
+
Status GpuPool2d::validate_op(const GpuWorkloadSketch &sketch,
const ITensorInfo *src,
const ITensorInfo *dst,
diff --git a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplatePool2d.cpp b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplatePool2d.cpp
index 5df4438afe..bbff8ba98f 100644
--- a/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplatePool2d.cpp
+++ b/src/dynamic_fusion/sketch/gpu/template_writer/cl/ClTemplatePool2d.cpp
@@ -390,11 +390,12 @@ TagLUT ClTemplatePool2d::get_tag_lut(const GpuKernelVariableTable &vtable, const
lut["meta_kernel_id"] = id();
// Retrieve relevant data
- const auto padding = _attributes.pad();
- const auto stride = _attributes.stride();
- const auto pool_size = _attributes.pool_size();
- const auto data_type = _src->data_type();
- const auto use_fp_mixed_precision = (_src->data_type() == DataType::F16) && _settings.mixed_precision() && _attributes.pool_type() != PoolingType::MAX;
+ const auto padding = _attributes.pad();
+ const auto stride = _attributes.stride();
+ const auto pool_size = _attributes.pool_size();
+ const auto data_type = _src->data_type();
+ const auto use_fp_mixed_precision = (_src->data_type() == DataType::F16) && _settings.mixed_precision() && _attributes.pool_type() != PoolingType::MAX;
+ const std::string max_initial_value = _settings.use_inf_as_limit() ? "(-INFINITY)" : float_to_string_with_full_precision(std::numeric_limits<float>::lowest());
// pool specific
lut["STRIDE_X"] = stride.x();
@@ -409,7 +410,7 @@ TagLUT ClTemplatePool2d::get_tag_lut(const GpuKernelVariableTable &vtable, const
lut["DATA_TYPE"] = get_cl_type_from_data_type(data_type);
lut["SRC_WIDTH"] = _src->dimension(width_idx);
lut["SRC_HEIGHT"] = _src->dimension(height_idx);
- lut["INITIAL_VALUE"] = (_attributes.pool_type() == PoolingType::MAX) ? float_to_string_with_full_precision(std::numeric_limits<float>::lowest()) : std::string("0");
+ lut["INITIAL_VALUE"] = (_attributes.pool_type() == PoolingType::MAX) ? max_initial_value : std::string("0");
// Tensor specific data
lut["DST_HEIGHT"] = _dst->dimension(height_idx);