Commit 635582b

-- added test cases for mdn and auto int
-- added autoint
-- refactored a few things
1 parent d8188a0 commit 635582b

13 files changed: +628 -131 lines changed

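For orientation, here is a minimal sketch of how the new AutoIntConfig is meant to be wired into a training run with the library's TabularModel wrapper. The wrapper import path, the DataConfig fields, and the fit() call are assumptions based on pytorch_tabular's usual API and are not part of this diff:

# Hypothetical end-to-end wiring; only AutoIntConfig itself is introduced by this commit.
from pytorch_tabular import TabularModel
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models import AutoIntConfig

data_config = DataConfig(
    target=["MedHouseVal"],                      # assumed target column
    continuous_cols=["AveRooms", "Population"],  # illustrative columns
    categorical_cols=["HouseAgeBin"],
)
model_config = AutoIntConfig(task="regression", deep_layers=True, embedding_dropout=0.2)
optimizer_config = OptimizerConfig()
trainer_config = TrainerConfig(max_epochs=5, gpus=0)

tabular_model = TabularModel(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
)
# tabular_model.fit(train=train_df, test=test_df)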

examples/to_test_regression.py

Lines changed: 7 additions & 3 deletions
@@ -12,6 +12,7 @@
 from pytorch_tabular.models.category_embedding.config import (
     CategoryEmbeddingModelConfig,
 )
+from pytorch_tabular.models import AutoIntModel, AutoIntConfig
 
 from pytorch_tabular.models.mixture_density import (
     CategoryEmbeddingMDNConfig, MixtureDensityHeadConfig, NODEMDNConfig
@@ -33,6 +34,8 @@
 dataset = fetch_california_housing(data_home="data", as_frame=True)
 dataset.frame["HouseAgeBin"] = pd.qcut(dataset.frame["HouseAge"], q=4)
 dataset.frame.HouseAgeBin = "age_" + dataset.frame.HouseAgeBin.cat.codes.astype(str)
+dataset.frame["AveRoomsBin"] = pd.qcut(dataset.frame["AveRooms"], q=3)
+dataset.frame.AveRoomsBin = "av_rm_" + dataset.frame.AveRoomsBin.cat.codes.astype(str)
 
 test_idx = dataset.frame.sample(int(0.2 * len(dataset.frame)), random_state=42).index
 test = dataset.frame[dataset.frame.index.isin(test_idx)]
@@ -49,7 +52,7 @@
         "Longitude",
     ],
     # continuous_cols=[],
-    categorical_cols=["HouseAgeBin"],
+    categorical_cols=["HouseAgeBin","AveRoomsBin"],
     continuous_feature_transform=None, # "yeo-johnson",
     normalize_continuous_features=True,
 )
@@ -61,8 +64,9 @@
 # mdn_config = mdn_config
 # )
 # # model_config.validate()
-model_config = NodeConfig(task="regression", depth=2, embed_categorical=False)
-trainer_config = TrainerConfig(checkpoints=None, max_epochs=5, gpus=1, profiler=None)
+# model_config = CategoryEmbeddingModelConfig(task="regression")
+model_config = AutoIntConfig(task="regression", deep_layers=True, embedding_dropout=0.2, batch_norm_continuous_input=True)
+trainer_config = TrainerConfig(checkpoints=None, max_epochs=25, gpus=1, profiler=None, fast_dev_run=False, auto_lr_find=True)
 # experiment_config = ExperimentConfig(
 # project_name="DeepGMM_test",
 # run_name="wand_debug",
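The new AveRoomsBin feature follows the same recipe as HouseAgeBin: quantile-bin a continuous column with pd.qcut, then turn the category codes into string labels so the column can be declared in categorical_cols. A toy illustration of that pattern (made-up data, independent of the example script):

# Quantile-binning a numeric column into labelled categories
import pandas as pd

df = pd.DataFrame({"AveRooms": [2.1, 3.5, 4.8, 5.2, 6.9, 7.3]})
df["AveRoomsBin"] = pd.qcut(df["AveRooms"], q=3)                        # 3 quantile bins (Categorical dtype)
df["AveRoomsBin"] = "av_rm_" + df["AveRoomsBin"].cat.codes.astype(str)  # -> "av_rm_0" .. "av_rm_2"
print(df["AveRoomsBin"].tolist())
# ['av_rm_0', 'av_rm_0', 'av_rm_1', 'av_rm_1', 'av_rm_2', 'av_rm_2']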

pytorch_tabular/models/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -8,9 +8,12 @@
     MixtureDensityHeadConfig,
     NODEMDNConfig,
     NODEMDN,
+    AutoIntMDN,
+    AutoIntMDNConfig
 )
+from .autoint import AutoIntConfig, AutoIntModel
 from .base_model import BaseModel
-from . import category_embedding, node, mixture_density, tabnet
+from . import category_embedding, node, mixture_density, tabnet, autoint
 
 __all__ = [
     "CategoryEmbeddingModel",
@@ -26,8 +29,13 @@
     "MixtureDensityHeadConfig",
     "NODEMDNConfig",
     "NODEMDN",
+    "AutoIntMDN",
+    "AutoIntMDNConfig",
+    "AutoIntConfig",
+    "AutoIntModel",
     "category_embedding",
     "node",
     "mixture_density",
     "tabnet",
+    "autoint",
 ]
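With these re-exports in place, the AutoInt classes resolve from the top-level models package, which is what the updated example script relies on:

# Imports made available by the updated __init__.py
from pytorch_tabular.models import AutoIntConfig, AutoIntModel
from pytorch_tabular.models import AutoIntMDN, AutoIntMDNConfig  # MDN variant re-exported from mixture_density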
Lines changed: 3 additions & 3 deletions
@@ -1,4 +1,4 @@
-from .category_embedding_model import CategoryEmbeddingModel, FeedForwardBackbone
-from .config import CategoryEmbeddingModelConfig
+from .autoint import AutoIntBackbone, AutoIntModel
+from .config import AutoIntConfig
 
-__all__ = ["CategoryEmbeddingModel", "CategoryEmbeddingModelConfig", "FeedForwardBackbone"]
+__all__ = ["AutoIntModel", "AutoIntBackbone", "AutoIntConfig"]

pytorch_tabular/models/autoint/autoint.py

Lines changed: 70 additions & 63 deletions
@@ -9,7 +9,8 @@
 import torch
 import torch.nn as nn
 from omegaconf import DictConfig
-from pytorch_tabular.utils import _initialize_layers
+
+from pytorch_tabular.utils import _initialize_layers, _linear_dropout_bn
 
 from ..base_model import BaseModel
 
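The backbone now imports a shared _linear_dropout_bn helper from pytorch_tabular.utils instead of carrying its own copy (the private method removed in the next hunk). The utility itself is not shown in this diff; based on the removed method and the new call site, it presumably looks roughly like this:

# Presumed shape of the shared helper (an inference, not the actual utils.py source)
import torch.nn as nn

def _linear_dropout_bn(hparams, in_units, out_units, activation, dropout):
    layers = []
    if hparams.use_batch_norm:
        layers.append(nn.BatchNorm1d(num_features=in_units))
    linear = nn.Linear(in_units, out_units)
    # _initialize_layers(hparams, linear)  # library weight-init hook, omitted here
    layers.extend([linear, activation()])
    if dropout != 0:
        layers.append(nn.Dropout(dropout))
    return layers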

@@ -21,40 +22,32 @@ def __init__(self, config: DictConfig, **kwargs):
         self.embedding_cat_dim = sum([y for x, y in config.embedding_dims])
         super().__init__(config, **kwargs)
 
-    def _linear_dropout_bn(self, in_units, out_units, activation, dropout):
-        layers = []
-        if self.hparams.use_batch_norm:
-            layers.append(nn.BatchNorm1d(num_features=in_units))
-        linear = nn.Linear(in_units, out_units)
-        _initialize_layers(self.hparams, linear)
-        layers.extend([linear, activation()])
-        if dropout != 0:
-            layers.append(nn.Dropout(dropout))
-        return layers
-
     def _build_network(self):
-        # Embedding layers
+        # Category Embedding layers
         self.cat_embedding_layers = nn.ModuleList(
-            [nn.Embedding(x, y) for x, y in self.hparams.cat_embedding_dims]
-        )
-        self.cont_embedding_layers = nn.ModuleList(
             [
-                nn.Embedding(1, self.hparams.cont_embedding_dim)
-                for i in range(self.hparams.continuous_dim)
+                nn.Embedding(cardinality, self.hparams.embedding_dim)
+                for cardinality in self.hparams.categorical_cardinality
             ]
         )
+        if self.hparams.batch_norm_continuous_input:
+            self.normalizing_batch_norm = nn.BatchNorm1d(self.hparams.continuous_dim)
+        # Continuous Embedding Layer
+        self.cont_embedding_layer = nn.Embedding(
+            self.hparams.continuous_dim, self.hparams.embedding_dim
+        )
         if self.hparams.embedding_dropout != 0 and self.embedding_cat_dim != 0:
             self.embed_dropout = nn.Dropout(self.hparams.embedding_dropout)
-        # if self.hparams.use_batch_norm:
-        #     self.normalizing_batch_norm = nn.BatchNorm1d(self.hparams.continuous_dim+self.hparams.embedding_cat_dim)
+        # Deep Layers
+        _curr_units = self.hparams.embedding_dim
         if self.hparams.deep_layers:
             activation = getattr(nn, self.hparams.activation)
             # Linear Layers
             layers = []
-            _curr_units = self.hparams.continuous_dim + self.embedding_cat_dim
             for units in self.hparams.layers.split("-"):
                 layers.extend(
-                    self._linear_dropout_bn(
+                    _linear_dropout_bn(
+                        self.hparams,
                         _curr_units,
                         int(units),
                         activation,
@@ -63,9 +56,10 @@ def _build_network(self):
                 )
                 _curr_units = int(units)
             self.linear_layers = nn.Sequential(*layers)
-        else:
-            _curr_units = self.hparams.continuous_dim + self.embedding_cat_dim
-
+        # Projection to Multi-Headed Attention Dims
+        self.attn_proj = nn.Linear(_curr_units, self.hparams.attn_embed_dim)
+        _initialize_layers(self.hparams, self.attn_proj)
+        # Multi-Headed Attention Layers
         self.self_attns = nn.ModuleList(
             [
                 nn.MultiheadAttention(
@@ -76,15 +70,56 @@
                 for _ in range(self.hparams.num_attn_blocks)
             ]
         )
-        self.atten_output_dim = (
-            len(self.hparams.continuous_cols + self.hparams.categorical_cols)
-            * self.hparams.atten_embed_dim
-        )
-
+        if self.hparams.has_residuals:
+            self.V_res_embedding = torch.nn.Linear(
+                _curr_units, self.hparams.attn_embed_dim
+            )
+        self.output_dim = (
+            self.hparams.continuous_dim + self.hparams.categorical_dim
+        ) * self.hparams.attn_embed_dim
 
-    def forward(self, x):
-        x = self.linear_layers(x)
-        return x
+    def forward(self, x: Dict):
+        # (B, N)
+        continuous_data, categorical_data = x["continuous"], x["categorical"]
+        x = None
+        if self.embedding_cat_dim != 0:
+            x_cat = [
+                embedding_layer(categorical_data[:, i]).unsqueeze(1)
+                for i, embedding_layer in enumerate(self.cat_embedding_layers)
+            ]
+            # (B, N, E)
+            x = torch.cat(x_cat, 1)
+        if self.hparams.continuous_dim > 0:
+            cont_idx = (
+                torch.arange(self.hparams.continuous_dim)
+                .expand(continuous_data.size(0), -1)
+                .to(self.device)
+            )
+            if self.hparams.batch_norm_continuous_input:
+                continuous_data = self.normalizing_batch_norm(continuous_data)
+            x_cont = torch.mul(
+                continuous_data.unsqueeze(2),
+                self.cont_embedding_layer(cont_idx),
+            )
+            # (B, N, E)
+            x = x_cont if x is None else torch.cat([x, x_cont], 1)
+        if self.hparams.embedding_dropout != 0 and self.embedding_cat_dim != 0:
+            x = self.embed_dropout(x)
+        if self.hparams.deep_layers:
+            x = self.linear_layers(x)
+        # (N, B, E*) --> E* is the Attn Dimension
+        cross_term = self.attn_proj(x).transpose(0, 1)
+        for self_attn in self.self_attns:
+            cross_term, _ = self_attn(cross_term, cross_term, cross_term)
+        # (B, N, E*)
+        cross_term = cross_term.transpose(0, 1)
+        if self.hparams.has_residuals:
+            # (B, N, E*) --> Projecting Embedded input to Attention sub-space
+            V_res = self.V_res_embedding(x)
+            cross_term = cross_term + V_res
+        # (B, NxE*)
+        cross_term = nn.ReLU()(cross_term).reshape(-1, self.output_dim)
+        return cross_term
 
 
 class AutoIntModel(BaseModel):
@@ -94,46 +129,18 @@ def __init__(self, config: DictConfig, **kwargs):
         super().__init__(config, **kwargs)
 
     def _build_network(self):
-        # Embedding layers
-        self.embedding_layers = nn.ModuleList(
-            [nn.Embedding(x, y) for x, y in self.hparams.embedding_dims]
-        )
-        # Continuous Layers
-        if self.hparams.batch_norm_continuous_input:
-            self.normalizing_batch_norm = nn.BatchNorm1d(self.hparams.continuous_dim)
         # Backbone
         self.backbone = AutoIntBackbone(self.hparams)
+        self.dropout = nn.Dropout(self.hparams.dropout)
         # Adding the last layer
         self.output_layer = nn.Linear(
             self.backbone.output_dim, self.hparams.output_dim
         ) # output_dim auto-calculated from other config
         _initialize_layers(self.hparams, self.output_layer)
 
-    def unpack_input(self, x: Dict):
-        continuous_data, categorical_data = x["continuous"], x["categorical"]
-        if self.embedding_cat_dim != 0:
-            x = []
-            # for i, embedding_layer in enumerate(self.embedding_layers):
-            #     x.append(embedding_layer(categorical_data[:, i]))
-            x = [
-                embedding_layer(categorical_data[:, i])
-                for i, embedding_layer in enumerate(self.embedding_layers)
-            ]
-            x = torch.cat(x, 1)
-
-        if self.hparams.continuous_dim != 0:
-            if self.hparams.batch_norm_continuous_input:
-                continuous_data = self.normalizing_batch_norm(continuous_data)
-
-            if self.embedding_cat_dim != 0:
-                x = torch.cat([x, continuous_data], 1)
-            else:
-                x = continuous_data
-        return x
-
     def forward(self, x: Dict):
-        x = self.unpack_input(x)
         x = self.backbone(x)
+        x = self.dropout(x)
         y_hat = self.output_layer(x)
         if (self.hparams.task == "regression") and (
             self.hparams.target_range is not None
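To make the rewritten backbone forward() easier to follow, here is a self-contained sketch of the same shape flow with plain torch modules. All dimensions and module sizes below are invented for illustration; this is not the library class:

# AutoInt-style shape flow: field embeddings -> self-attention over fields -> flatten
import torch
import torch.nn as nn

B, n_cat, n_cont, E, E_attn, heads = 32, 2, 6, 16, 8, 2
N = n_cat + n_cont

cat_embeddings = nn.ModuleList([nn.Embedding(4, E), nn.Embedding(3, E)])
cont_embedding = nn.Embedding(n_cont, E)   # one learned embedding per continuous field
attn_proj = nn.Linear(E, E_attn)
self_attn = nn.MultiheadAttention(E_attn, heads)

categorical = torch.randint(0, 3, (B, n_cat))
continuous = torch.randn(B, n_cont)

# One embedding vector per categorical field: (B, n_cat, E)
x_cat = torch.cat(
    [emb(categorical[:, i]).unsqueeze(1) for i, emb in enumerate(cat_embeddings)], dim=1
)
# Each continuous value scales its field's learned embedding: (B, n_cont, E)
cont_idx = torch.arange(n_cont).expand(B, -1)
x_cont = continuous.unsqueeze(2) * cont_embedding(cont_idx)

x = torch.cat([x_cat, x_cont], dim=1)            # (B, N, E)
cross = attn_proj(x).transpose(0, 1)             # (N, B, E_attn) -- seq-first for nn.MultiheadAttention
cross, _ = self_attn(cross, cross, cross)        # self-attention across the N fields
cross = cross.transpose(0, 1)                    # (B, N, E_attn)
out = torch.relu(cross).reshape(B, N * E_attn)   # (B, N*E_attn), fed to the output layer in the model
print(out.shape)  # torch.Size([32, 64])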
