Skip to content

Commit d8b061f

Browse files
committed
add test for parallelisation
1 parent 4b5546f commit d8b061f

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

imblearn/combine/tests/test_smote_enn.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,22 @@ def test_validate_estimator_default():
9898
assert_array_equal(y_resampled, y_gt)
9999

100100

101+
def test_parallelisation():
102+
# Check if default job count is 1
103+
smt = SMOTEENN(random_state=RND_SEED)
104+
smt._validate_estimator()
105+
assert smt.n_jobs == 1
106+
assert smt.smote_.n_jobs == 1
107+
assert smt.enn_.n_jobs == 1
108+
109+
# Check if job count is set
110+
smt = SMOTEENN(random_state=RND_SEED, n_jobs=8)
111+
smt._validate_estimator()
112+
assert smt.n_jobs == 8
113+
assert smt.smote_.n_jobs == 8
114+
assert smt.enn_.n_jobs == 8
115+
116+
101117
@pytest.mark.parametrize(
102118
"smote_params, err_msg",
103119
[({'smote': 'rnd'}, "smote needs to be a SMOTE"),

imblearn/combine/tests/test_smote_tomek.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,22 @@ def test_validate_estimator_default():
104104
assert_array_equal(y_resampled, y_gt)
105105

106106

107+
def test_parallelisation():
108+
# Check if default job count is 1
109+
smt = SMOTETomek(random_state=RND_SEED)
110+
smt._validate_estimator()
111+
assert smt.n_jobs == 1
112+
assert smt.smote_.n_jobs == 1
113+
assert smt.tomek_.n_jobs == 1
114+
115+
# Check if job count is set
116+
smt = SMOTETomek(random_state=RND_SEED, n_jobs=8)
117+
smt._validate_estimator()
118+
assert smt.n_jobs == 8
119+
assert smt.smote_.n_jobs == 8
120+
assert smt.tomek_.n_jobs == 8
121+
122+
107123
@pytest.mark.parametrize(
108124
"smote_params, err_msg",
109125
[({'smote': 'rnd'}, "smote needs to be a SMOTE"),

0 commit comments

Comments
 (0)