Coverage for anfis_toolbox/optim/sgd.py: 100%
33 statements
« prev ^ index » next — coverage.py v7.13.3, created at 2026-02-05 18:47 -0300
1from __future__ import annotations
3from dataclasses import dataclass, field
4from typing import Any
6import numpy as np
8from ..losses import LossFunction
9from .base import BaseTrainer, ModelLike
@dataclass
class SGDTrainer(BaseTrainer):
    """Stochastic gradient descent trainer for ANFIS.

    Parameters:
        learning_rate: Step size for gradient descent.
        epochs: Number of passes over the data.
        batch_size: Mini-batch size; if None uses full batch.
        shuffle: Whether to shuffle data each epoch.
        verbose: Whether to log progress (delegated to model logging settings).

    Notes:
        Uses the configurable loss provided via ``loss`` (defaults to mean squared error).
        The selected loss is responsible for adapting target shapes via ``prepare_targets``.
        When used with ``ANFISClassifier`` and ``loss="cross_entropy"`` it trains on logits with the
        appropriate softmax gradient.
    """

    learning_rate: float = 0.01
    epochs: int = 100
    batch_size: None | int = None
    shuffle: bool = True
    verbose: bool = False
    loss: LossFunction | str | None = None
    # Resolved loss instance; excluded from __init__/repr and populated lazily
    # via ``_get_loss_fn`` (presumably provided by BaseTrainer — not visible here).
    _loss_fn: LossFunction = field(init=False, repr=False)

    def init_state(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> None:
        """SGD has no persistent optimizer state; returns None."""
        return None

    def train_step(self, model: ModelLike, Xb: np.ndarray, yb: np.ndarray, state: Any) -> tuple[float, Any]:
        """Perform one SGD step on a batch and return ``(loss, state)``.

        ``state`` is passed through unchanged because plain SGD is stateless.
        """
        loss = self._compute_loss_backward_and_update(model, Xb, yb)
        return loss, state

    def _compute_loss_backward_and_update(self, model: ModelLike, Xb: np.ndarray, yb: np.ndarray) -> float:
        """Forward -> configured loss -> backward -> parameter update; returns the batch loss.

        Gradients are reset before the forward pass so each step uses only
        the current batch's gradient. The loss used is whatever
        ``_get_loss_fn`` resolves from ``self.loss`` (not necessarily MSE).
        """
        loss_fn = self._get_loss_fn()
        model.reset_gradients()
        y_pred = model.forward(Xb)
        loss = loss_fn.loss(yb, y_pred)
        dL_dy = loss_fn.gradient(yb, y_pred)
        model.backward(dL_dy)
        model.update_parameters(self.learning_rate)
        # Cast to a builtin float for consistency with ``compute_loss`` and the
        # annotated return type (loss implementations may return NumPy scalars).
        return float(loss)

    def compute_loss(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> float:
        """Return the loss for ``(X, y)`` without mutating ``model``."""
        loss_fn = self._get_loss_fn()
        preds = model.forward(X)
        return float(loss_fn.loss(y, preds))