From e45ebd0d7d2f7f90bc9a3f39444c392ee0526815 Mon Sep 17 00:00:00 2001 From: Davide Date: Tue, 30 Dec 2025 16:49:09 +0100 Subject: [PATCH] add plugins to cross_val and loo --- ezyrb/reducedordermodel.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/ezyrb/reducedordermodel.py b/ezyrb/reducedordermodel.py index 31e2f83..ce61873 100644 --- a/ezyrb/reducedordermodel.py +++ b/ezyrb/reducedordermodel.py @@ -451,7 +451,8 @@ def kfold_cv_error(self, n_splits, *args, norm=np.linalg.norm, **kwargs): for train_index, test_index in kf.split(self.database): new_db = self.database[train_index] rom = type(self)(new_db, copy.deepcopy(self.reduction), - copy.deepcopy(self.approximation)).fit( + copy.deepcopy(self.approximation), + plugins=[copy.deepcopy(p) for p in self.plugins]).fit( *args, **kwargs) error.append(rom.test_error(self.database[test_index], norm)) @@ -487,7 +488,8 @@ def loo_error(self, *args, norm=np.linalg.norm, **kwargs): new_db = self.database[indeces] test_db = self.database[~indeces] rom = type(self)(new_db, copy.deepcopy(self.reduction), - copy.deepcopy(self.approximation)).fit() + copy.deepcopy(self.approximation), + plugins=[copy.deepcopy(p) for p in self.plugins]).fit() error[j] = rom.test_error(test_db) @@ -860,6 +862,7 @@ def kfold_cv_error(self, n_splits, *args, norm=np.linalg.norm, relative=True, kf = KFold(n_splits=n_splits) for train_index, test_index in kf.split(self.database): new_db = self.database[train_index] + # TODO: Fix plugins handling - should pass: plugins=[copy.deepcopy(p) for p in self.plugins] rom = type(self)(new_db, copy.deepcopy(self.reduction), copy.deepcopy(self.approximation)).fit( *args, **kwargs) @@ -896,6 +899,7 @@ def loo_error(self, *args, norm=np.linalg.norm, **kwargs): new_db = self.database[indeces] test_db = self.database[~indeces] + # TODO: Fix plugins handling - should pass: plugins=[copy.deepcopy(p) for p in self.plugins] rom = type(self)(new_db, copy.deepcopy(self.reduction), copy.deepcopy(self.approximation)).fit()