Coverage for anfis_toolbox/optim/base.py: 100%

81 statements  


1"""Base classes and interfaces for ANFIS trainers. 

2 

3Defines the shared training loop and contracts used by all optimizers. Concrete 

4trainers specialize the ``train_step`` (and related helpers) while the base 

5class takes care of batching, epoch bookkeeping, optional validation, and 

6logging. 

7 

8Model contract expected by trainers: 

9- For pure backprop trainers (e.g., SGD/Adam): the model must provide 

10 ``reset_gradients()``, ``forward(X)``, ``backward(dL_dy)``, and 

11 ``update_parameters(lr)``. 

12- For the HybridTrainer, the model must expose the usual ANFIS layers 

13 (``membership_layer``, ``rule_layer``, ``normalization_layer``, 

14 ``consequent_layer``) to build the least-squares system internally. 

15""" 

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import Any, TypeAlias

import numpy as np

from ..losses import LossFunction, resolve_loss
from ..model import TSKANFIS, TrainingHistory, TSKANFISClassifier

ModelLike: TypeAlias = TSKANFIS | TSKANFISClassifier

__all__ = ["BaseTrainer", "ModelLike"]


class BaseTrainer(ABC):
    """Shared training loop for ANFIS trainers."""

    def fit(
        self,
        model: ModelLike,
        X: np.ndarray,
        y: np.ndarray,
        *,
        validation_data: tuple[np.ndarray, np.ndarray] | None = None,
        validation_frequency: int = 1,
    ) -> TrainingHistory:
        """Train ``model`` on ``(X, y)`` and optionally evaluate on validation data.

        Returns a dictionary containing the per-epoch training losses and, when
        ``validation_data`` is provided, the validation losses (aligned with the
        training epochs; epochs without validation are recorded as ``None``).
        """
        if validation_frequency < 1:
            raise ValueError("validation_frequency must be >= 1")

        X_train, y_train = self._prepare_training_data(model, X, y)
        state = self.init_state(model, X_train, y_train)

        prepared_val: tuple[np.ndarray, np.ndarray] | None = None
        if validation_data is not None:
            prepared_val = self._prepare_validation_data(model, *validation_data)

        epochs = int(getattr(self, "epochs", 1))
        batch_size = getattr(self, "batch_size", None)
        shuffle = bool(getattr(self, "shuffle", True))
        verbose = bool(getattr(self, "verbose", False))

        train_history: list[float] = []
        val_history: list[float | None] = []

        n_samples = X_train.shape[0]
        for epoch_idx in range(epochs):
            epoch_losses: list[float] = []
            if batch_size is None:
                loss, state = self.train_step(model, X_train, y_train, state)
                epoch_losses.append(float(loss))
            else:
                indices = np.arange(n_samples)
                if shuffle:
                    np.random.shuffle(indices)
                for start in range(0, n_samples, batch_size):
                    end = start + batch_size
                    batch_idx = indices[start:end]
                    loss, state = self.train_step(
                        model,
                        X_train[batch_idx],
                        y_train[batch_idx],
                        state,
                    )
                    epoch_losses.append(float(loss))

            epoch_loss = float(np.mean(epoch_losses)) if epoch_losses else 0.0
            train_history.append(epoch_loss)

            val_loss: float | None = None
            if prepared_val is not None:
                if (epoch_idx + 1) % validation_frequency == 0:
                    X_val, y_val = prepared_val
                    val_loss = float(self.compute_loss(model, X_val, y_val))
                val_history.append(val_loss)

            self._log_epoch(epoch_idx, epoch_loss, val_loss, verbose)

        result: TrainingHistory = {"train": train_history}
        if prepared_val is not None:
            result["val"] = val_history
        return result
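
    # For example (illustrative values), with ``epochs = 3``,
    # ``validation_frequency = 2`` and validation data provided, ``fit``
    # returns something like:
    #     {"train": [0.91, 0.55, 0.40], "val": [None, 0.60, None]}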

    @abstractmethod
    def init_state(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> Any:  # pragma: no cover - abstract
        """Initialize and return any optimizer-specific state.

        Called once before training begins. Trainers that don't require state may
        return ``None``.

        Parameters:
            model: The model to be trained.
            X (np.ndarray): The full training inputs.
            y (np.ndarray): The full training targets.

        Returns:
            Any: Optimizer state (or ``None``) to be threaded through ``train_step``.
        """
        raise NotImplementedError

    @abstractmethod
    def train_step(
        self, model: ModelLike, Xb: np.ndarray, yb: np.ndarray, state: Any
    ) -> tuple[float, Any]:  # pragma: no cover - abstract
        """Perform a single training step on a batch and return ``(loss, new_state)``.

        Parameters:
            model: The model to be trained.
            Xb (np.ndarray): A batch of inputs.
            yb (np.ndarray): A batch of targets.
            state: Optimizer state produced by ``init_state``.

        Returns:
            tuple[float, Any]: The batch loss and the updated optimizer state.
        """
        raise NotImplementedError

    @abstractmethod
    def compute_loss(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> float:  # pragma: no cover - abstract
        """Compute the loss for the provided data without mutating the model."""
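
    # The three hooks are threaded through ``fit`` roughly as follows
    # (illustrative outline of the loop above):
    #
    #     state = self.init_state(model, X_train, y_train)
    #     for each epoch, for each batch (Xb, yb):
    #         loss, state = self.train_step(model, Xb, yb, state)
    #     val_loss = self.compute_loss(model, X_val, y_val)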

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _prepare_training_data(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Prepare training data by converting inputs to arrays and using the loss function to shape targets."""
        loss_fn = self._get_loss_fn()
        X_arr = np.asarray(X, dtype=float)
        y_arr = loss_fn.prepare_targets(y, model=model)
        if y_arr.shape[0] != X_arr.shape[0]:
            raise ValueError("Target array must have same number of rows as X")
        return X_arr, y_arr

    def _prepare_validation_data(
        self,
        model: ModelLike,
        X_val: np.ndarray,
        y_val: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Prepare validation data using the same logic as training data."""
        loss_fn = self._get_loss_fn()
        X_arr = np.asarray(X_val, dtype=float)
        y_arr = loss_fn.prepare_targets(y_val, model=model)
        if y_arr.shape[0] != X_arr.shape[0]:
            raise ValueError("Target array must have same number of rows as X")
        return X_arr, y_arr

    def _get_loss_fn(self) -> LossFunction:
        """Get or initialize the loss function (always returns a valid ``LossFunction``)."""
        if hasattr(self, "__dataclass_fields__") and "_loss_fn" in self.__dataclass_fields__:
            # For dataclass trainers with a ``_loss_fn`` field, initialize it once.
            if not hasattr(self, "_loss_fn"):
                loss_attr = getattr(self, "loss", None)
                self._loss_fn = resolve_loss(loss_attr)
            return self._loss_fn
        else:
            # For non-dataclass trainers, resolve on the fly.
            loss_attr = getattr(self, "loss", None)
            return resolve_loss(loss_attr)
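
    # Note (illustrative): a dataclass trainer that declares a ``_loss_fn``
    # field resolves its ``loss`` attribute through ``resolve_loss`` once and
    # caches the result on the instance; a plain-class trainer re-resolves it
    # on every call.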

183 

184 def _log_epoch( 

185 self, 

186 epoch_idx: int, 

187 train_loss: float, 

188 val_loss: float | None, 

189 verbose: bool, 

190 ) -> None: 

191 if not verbose: 

192 return 

193 logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}") 

194 message = f"Epoch {epoch_idx + 1} - train_loss: {train_loss:.6f}" 

195 if val_loss is not None: 

196 message += f" - val_loss: {val_loss:.6f}" 

197 logger.info(message)
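
For orientation, here is a minimal sketch of a concrete trainer built on this
base class. It is illustrative only: ``GDTrainer`` is hypothetical, the
``"mse"`` loss spec is assumed to be accepted by ``resolve_loss``, and the
``loss_fn.loss`` / ``loss_fn.gradient`` method names are assumptions about the
``LossFunction`` interface (only ``prepare_targets`` appears in this file).
The model methods used follow the backprop contract from the module docstring.

    from dataclasses import dataclass
    from typing import Any

    import numpy as np

    from anfis_toolbox.optim.base import BaseTrainer, ModelLike


    @dataclass
    class GDTrainer(BaseTrainer):
        """Plain gradient-descent trainer (illustrative sketch)."""

        lr: float = 0.01
        epochs: int = 100
        batch_size: int | None = 32
        loss: str = "mse"  # resolved lazily via ``_get_loss_fn``

        def init_state(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> Any:
            # Plain gradient descent carries no optimizer state.
            return None

        def train_step(
            self, model: ModelLike, Xb: np.ndarray, yb: np.ndarray, state: Any
        ) -> tuple[float, Any]:
            loss_fn = self._get_loss_fn()
            model.reset_gradients()
            y_pred = model.forward(Xb)
            model.backward(loss_fn.gradient(yb, y_pred))  # assumed LossFunction API
            model.update_parameters(self.lr)
            return float(loss_fn.loss(yb, y_pred)), state  # assumed LossFunction API

        def compute_loss(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> float:
            return float(self._get_loss_fn().loss(y, model.forward(X)))  # assumed API

A call such as ``GDTrainer(lr=0.05, epochs=50).fit(model, X, y)`` would then
return the ``TrainingHistory`` dictionary documented on ``fit``.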