Coverage for anfis_toolbox/optim/sgd.py: 100%

33 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-05 18:47 -0300

1from __future__ import annotations 

2 

3from dataclasses import dataclass, field 

4from typing import Any 

5 

6import numpy as np 

7 

8from ..losses import LossFunction 

9from .base import BaseTrainer, ModelLike 

10 

11 

@dataclass
class SGDTrainer(BaseTrainer):
    """Stochastic gradient descent trainer for ANFIS.

    Parameters:
        learning_rate: Step size for gradient descent.
        epochs: Number of passes over the data.
        batch_size: Mini-batch size; if None uses full batch.
        shuffle: Whether to shuffle data each epoch.
        verbose: Whether to log progress (delegated to model logging settings).

    Notes:
        Uses the configurable loss provided via ``loss`` (defaults to mean squared error).
        The selected loss is responsible for adapting target shapes via ``prepare_targets``.
        When used with ``ANFISClassifier`` and ``loss="cross_entropy"`` it trains on logits with the
        appropriate softmax gradient.
    """

    learning_rate: float = 0.01
    epochs: int = 100
    batch_size: None | int = None
    shuffle: bool = True
    verbose: bool = False
    loss: LossFunction | str | None = None
    # Resolved loss instance; excluded from __init__/repr and populated lazily
    # through ``_get_loss_fn`` — presumably provided by BaseTrainer (not visible here).
    _loss_fn: LossFunction = field(init=False, repr=False)

    def init_state(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> None:
        """SGD has no persistent optimizer state; returns None."""
        return None

    def train_step(self, model: ModelLike, Xb: np.ndarray, yb: np.ndarray, state: Any) -> tuple[float, Any]:
        """Perform one SGD step on a batch and return ``(loss, state)``.

        ``state`` is passed through unchanged since plain SGD keeps no
        optimizer state between steps.
        """
        loss = self._compute_loss_backward_and_update(model, Xb, yb)
        return loss, state

    def _compute_loss_backward_and_update(self, model: ModelLike, Xb: np.ndarray, yb: np.ndarray) -> float:
        """Forward -> loss -> backward -> parameter update; returns the batch loss.

        Uses the configured loss function (not necessarily MSE). Mutates
        ``model`` in place: gradients are reset, recomputed via ``backward``,
        and applied with ``update_parameters``.
        """
        loss_fn = self._get_loss_fn()
        model.reset_gradients()
        y_pred = model.forward(Xb)
        loss = loss_fn.loss(yb, y_pred)
        dL_dy = loss_fn.gradient(yb, y_pred)
        model.backward(dL_dy)
        model.update_parameters(self.learning_rate)
        # Coerce to a plain float so the declared return type holds even when
        # the loss implementation returns a NumPy scalar; mirrors compute_loss.
        return float(loss)

    def compute_loss(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> float:
        """Return the loss for ``(X, y)`` without mutating ``model``."""
        loss_fn = self._get_loss_fn()
        preds = model.forward(X)
        return float(loss_fn.loss(y, preds))