Coverage for anfis_toolbox/optim/base.py: 100% of 81 statements (coverage.py v7.13.3, created 2026-02-05 18:47 -0300)
1"""Base classes and interfaces for ANFIS trainers.
3Defines the shared training loop and contracts used by all optimizers. Concrete
4trainers specialize the ``train_step`` (and related helpers) while the base
5class takes care of batching, epoch bookkeeping, optional validation, and
6logging.
8Model contract expected by trainers:
9- For pure backprop trainers (e.g., SGD/Adam): the model must provide
10 ``reset_gradients()``, ``forward(X)``, ``backward(dL_dy)``, and
11 ``update_parameters(lr)``.
12- For the HybridTrainer, the model must expose the usual ANFIS layers
13 (``membership_layer``, ``rule_layer``, ``normalization_layer``,
14 ``consequent_layer``) to build the least-squares system internally.
15"""
from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import Any, TypeAlias

import numpy as np

from ..losses import LossFunction, resolve_loss
from ..model import TSKANFIS, TrainingHistory, TSKANFISClassifier

ModelLike: TypeAlias = TSKANFIS | TSKANFISClassifier

__all__ = ["BaseTrainer", "ModelLike"]
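
# Illustrative sketch (not part of this module): the minimal surface a model
# needs to satisfy the pure-backprop contract from the module docstring.
# ``TinyBackpropModel`` and its method bodies are assumptions for illustration;
# the real models are ``TSKANFIS`` / ``TSKANFISClassifier``.
#
#     class TinyBackpropModel:
#         def reset_gradients(self) -> None:
#             """Zero any accumulated parameter gradients."""
#
#         def forward(self, X: np.ndarray) -> np.ndarray:
#             """Return predictions for a batch of inputs."""
#
#         def backward(self, dL_dy: np.ndarray) -> None:
#             """Backpropagate the loss gradient w.r.t. the outputs."""
#
#         def update_parameters(self, lr: float) -> None:
#             """Apply one gradient step with learning rate ``lr``."""
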
class BaseTrainer(ABC):
    """Shared training loop for ANFIS trainers."""

    def fit(
        self,
        model: ModelLike,
        X: np.ndarray,
        y: np.ndarray,
        *,
        validation_data: tuple[np.ndarray, np.ndarray] | None = None,
        validation_frequency: int = 1,
    ) -> TrainingHistory:
        """Train ``model`` on ``(X, y)`` and optionally evaluate on validation data.

        Returns a dictionary containing the per-epoch training losses and, when
        ``validation_data`` is provided, the validation losses (aligned with the
        training epochs; epochs without validation are recorded as ``None``).
        """
        if validation_frequency < 1:
            raise ValueError("validation_frequency must be >= 1")

        X_train, y_train = self._prepare_training_data(model, X, y)
        state = self.init_state(model, X_train, y_train)

        prepared_val: tuple[np.ndarray, np.ndarray] | None = None
        if validation_data is not None:
            prepared_val = self._prepare_validation_data(model, *validation_data)

        epochs = int(getattr(self, "epochs", 1))
        batch_size = getattr(self, "batch_size", None)
        shuffle = bool(getattr(self, "shuffle", True))
        verbose = bool(getattr(self, "verbose", False))

        train_history: list[float] = []
        val_history: list[float | None] = []

        n_samples = X_train.shape[0]
        for epoch_idx in range(epochs):
            epoch_losses: list[float] = []
            if batch_size is None:
                loss, state = self.train_step(model, X_train, y_train, state)
                epoch_losses.append(float(loss))
            else:
                indices = np.arange(n_samples)
                if shuffle:
                    np.random.shuffle(indices)
                for start in range(0, n_samples, batch_size):
                    end = start + batch_size
                    batch_idx = indices[start:end]
                    loss, state = self.train_step(
                        model,
                        X_train[batch_idx],
                        y_train[batch_idx],
                        state,
                    )
                    epoch_losses.append(float(loss))

            epoch_loss = float(np.mean(epoch_losses)) if epoch_losses else 0.0
            train_history.append(epoch_loss)

            val_loss: float | None = None
            if prepared_val is not None:
                if (epoch_idx + 1) % validation_frequency == 0:
                    X_val, y_val = prepared_val
                    val_loss = float(self.compute_loss(model, X_val, y_val))
                val_history.append(val_loss)

            self._log_epoch(epoch_idx, epoch_loss, val_loss, verbose)

        result: TrainingHistory = {"train": train_history}
        if prepared_val is not None:
            result["val"] = val_history
        return result
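
    # Usage sketch (illustrative): ``fit`` drives the whole loop, reading
    # hyperparameters off the concrete trainer via ``getattr`` above.
    # ``AdamTrainer`` and its constructor arguments are assumptions for
    # illustration only.
    #
    #     trainer = AdamTrainer(epochs=50, batch_size=32, verbose=True)
    #     history = trainer.fit(
    #         model, X, y,
    #         validation_data=(X_val, y_val),
    #         validation_frequency=5,
    #     )
    #     history["train"]  # 50 per-epoch losses
    #     history["val"]    # 50 entries, ``None`` except every 5th epoch
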
    @abstractmethod
    def init_state(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> Any:  # pragma: no cover - abstract
        """Initialize and return any optimizer-specific state.

        Called once before training begins. Trainers that don't require state
        may return ``None``.

        Parameters:
            model: The model to be trained.
            X (np.ndarray): The full training inputs.
            y (np.ndarray): The full training targets.

        Returns:
            Any: Optimizer state (or ``None``) to be threaded through
            ``train_step``.
        """
        raise NotImplementedError

    @abstractmethod
    def train_step(
        self, model: ModelLike, Xb: np.ndarray, yb: np.ndarray, state: Any
    ) -> tuple[float, Any]:  # pragma: no cover - abstract
        """Perform a single training step on a batch and return ``(loss, new_state)``.

        A concrete implementation is sketched at the end of this module.

        Parameters:
            model: The model to be trained.
            Xb (np.ndarray): A batch of inputs.
            yb (np.ndarray): A batch of targets.
            state: Optimizer state produced by ``init_state``.

        Returns:
            tuple[float, Any]: The batch loss and the updated optimizer state.
        """
        raise NotImplementedError
    @abstractmethod
    def compute_loss(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> float:  # pragma: no cover - abstract
        """Compute the loss for the provided data without mutating the model."""

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _prepare_training_data(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Convert inputs to float arrays and delegate target preparation to the loss function."""
        loss_fn = self._get_loss_fn()
        X_arr = np.asarray(X, dtype=float)
        y_arr = loss_fn.prepare_targets(y, model=model)
        if y_arr.shape[0] != X_arr.shape[0]:
            raise ValueError("Target array must have same number of rows as X")
        return X_arr, y_arr

    def _prepare_validation_data(
        self,
        model: ModelLike,
        X_val: np.ndarray,
        y_val: np.ndarray,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Prepare validation data with the same logic as training data."""
        loss_fn = self._get_loss_fn()
        X_arr = np.asarray(X_val, dtype=float)
        y_arr = loss_fn.prepare_targets(y_val, model=model)
        if y_arr.shape[0] != X_arr.shape[0]:
            raise ValueError("Target array must have same number of rows as X")
        return X_arr, y_arr
    def _get_loss_fn(self) -> LossFunction:
        """Get or initialize the loss function (always returns a valid ``LossFunction``)."""
        if hasattr(self, "__dataclass_fields__") and "_loss_fn" in self.__dataclass_fields__:
            # For dataclass trainers with a _loss_fn field, resolve it once and cache
            if not hasattr(self, "_loss_fn"):
                loss_attr = getattr(self, "loss", None)
                self._loss_fn = resolve_loss(loss_attr)
            return self._loss_fn
        else:
            # For non-dataclass trainers, resolve on the fly
            loss_attr = getattr(self, "loss", None)
            return resolve_loss(loss_attr)
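
    # Illustrative sketch: the cached branch above fires only for dataclass
    # trainers that declare ``_loss_fn`` as a field. ``MyTrainer`` below is an
    # assumption for illustration (the name is not from this file):
    #
    #     from dataclasses import dataclass, field
    #
    #     @dataclass
    #     class MyTrainer(BaseTrainer):
    #         loss: str | LossFunction | None = None
    #         _loss_fn: LossFunction = field(init=False, repr=False)
    #
    # With ``init=False`` and no default, ``_loss_fn`` stays unset until the
    # first ``_get_loss_fn()`` call resolves ``loss`` and caches the result.
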
    def _log_epoch(
        self,
        epoch_idx: int,
        train_loss: float,
        val_loss: float | None,
        verbose: bool,
    ) -> None:
        if not verbose:
            return
        logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
        message = f"Epoch {epoch_idx + 1} - train_loss: {train_loss:.6f}"
        if val_loss is not None:
            message += f" - val_loss: {val_loss:.6f}"
        logger.info(message)
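

# ----------------------------------------------------------------------
# Illustrative sketch (appendix, not part of the module's public API): a
# minimal concrete trainer wired to the pure-backprop model contract above.
# The class name and hyperparameters are assumptions for illustration, and
# the ``loss_fn.loss`` / ``loss_fn.gradient`` calls assume a LossFunction
# API beyond the ``prepare_targets`` / ``resolve_loss`` names confirmed in
# this file.
# ----------------------------------------------------------------------
class GradientDescentTrainer(BaseTrainer):
    """Plain gradient descent; full-batch unless ``batch_size`` is set."""

    def __init__(
        self,
        learning_rate: float = 0.01,
        epochs: int = 100,
        batch_size: int | None = None,
        loss: str | LossFunction | None = None,
        verbose: bool = False,
    ) -> None:
        self.learning_rate = learning_rate
        self.epochs = epochs          # read by ``fit`` via getattr
        self.batch_size = batch_size  # None -> one full-batch step per epoch
        self.loss = loss              # resolved lazily by ``_get_loss_fn``
        self.verbose = verbose

    def init_state(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> None:
        return None  # stateless: nothing to thread through ``train_step``

    def train_step(
        self, model: ModelLike, Xb: np.ndarray, yb: np.ndarray, state: Any
    ) -> tuple[float, Any]:
        loss_fn = self._get_loss_fn()
        model.reset_gradients()
        y_pred = model.forward(Xb)
        loss = loss_fn.loss(yb, y_pred)               # assumed LossFunction API
        model.backward(loss_fn.gradient(yb, y_pred))  # assumed LossFunction API
        model.update_parameters(self.learning_rate)
        return float(loss), state

    def compute_loss(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> float:
        loss_fn = self._get_loss_fn()
        return float(loss_fn.loss(y, model.forward(X)))  # assumed LossFunction API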