Skip to content

Commit 20c24c3

Browse files
committed
-- fixed an issue with torch.clip vs torch.clamp
1 parent b9790d4 commit 20c24c3

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,8 @@ loaded_model = TabularModel.load_from_checkpoint("examples/basic")
108108
```
109109
## Blog
110110

111-
_TBD_
111+
[PyTorch Tabular – A Framework for Deep Learning for Tabular Data](https://deep-and-shallow.com/2021/01/27/pytorch-tabular-a-framework-for-deep-learning-for-tabular-data/)
112+
112113
## References and Citations
113114

114115
[1] Sergei Popov, Stanislav Morozov, Artem Babenko. [*"Neural Oblivious Decision Ensembles for Deep Learning on Tabular Data"*](https://arxiv.org/abs/1909.06312). arXiv:1909.06312 [cs.LG] (2019)

pytorch_tabular/models/base_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,15 +101,15 @@ def calculate_loss(self, y, y_hat, tag):
101101

102102
def calculate_metrics(self, y, y_hat, tag):
103103
metrics = []
104-
y_hat = torch.clip(y_hat, min=0)
104+
y_hat = torch.clamp(y_hat, min=0)
105105
for metric, metric_str, metric_params in zip(self.metrics, self.hparams.metrics, self.hparams.metrics_params):
106106
if (self.hparams.task == "regression") and (self.hparams.output_dim > 1):
107107
_metrics = []
108108
for i in range(self.hparams.output_dim):
109109
if metric.__name__==pl.metrics.functional.mean_squared_log_error.__name__:
110110
# MSLE should only be used in strictly positive targets. It is undefined otherwise
111111
_metric = metric(
112-
torch.clip(y_hat[:, i], min=0), torch.clip(y[:, i], min=0), **metric_params
112+
torch.clamp(y_hat[:, i], min=0), torch.clamp(y[:, i], min=0), **metric_params
113113
)
114114
else:
115115
_metric = metric(y_hat[:, i], y[:, i], **metric_params)

0 commit comments

Comments
 (0)