Coverage for anfis_toolbox/optim/rmsprop.py: 100%

51 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-05 18:47 -0300

1from __future__ import annotations 

2 

3from dataclasses import dataclass, field 

4from typing import Any, cast 

5 

6import numpy as np 

7 

8from ..losses import LossFunction 

9from ._utils import ( 

10 iterate_membership_params_with_state, 

11 update_membership_param, 

12 zeros_like_structure, 

13) 

14from .base import BaseTrainer, ModelLike 

15 

16 

@dataclass
class RMSPropTrainer(BaseTrainer):
    """Trainer that fits an ANFIS model with the RMSProp update rule.

    RMSProp keeps an exponential moving average of squared gradients and
    scales each step by the inverse root of that average, giving a
    per-parameter adaptive learning rate.

    Parameters:
        learning_rate: Base step size (alpha).
        rho: Exponential decay rate for the squared gradient moving average.
        epsilon: Small constant for numerical stability.
        epochs: Number of passes over the dataset.
        batch_size: If None, use full-batch; otherwise mini-batches of this size.
        shuffle: Whether to shuffle the data at each epoch when using mini-batches.
        verbose: Unused here; kept for API parity.

    Notes:
        The objective is configurable through ``loss``. It defaults to mean
        squared error, but any differentiable loss (e.g. categorical
        cross-entropy for ``ANFISClassifier`` models) may be supplied.
    """

    learning_rate: float = 0.001
    rho: float = 0.9
    epsilon: float = 1e-8
    epochs: int = 100
    batch_size: None | int = None
    shuffle: bool = True
    verbose: bool = False
    loss: LossFunction | str | None = None
    _loss_fn: LossFunction = field(init=False, repr=False)

    def init_state(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> dict[str, Any]:
        """Build the optimizer state: current parameters plus zeroed caches."""
        current = model.get_parameters()
        # The cache mirrors the parameter structure so every scalar/array
        # parameter has its own squared-gradient accumulator.
        return {"params": current, "cache": zeros_like_structure(current)}

    def train_step(
        self, model: ModelLike, Xb: np.ndarray, yb: np.ndarray, state: dict[str, Any]
    ) -> tuple[float, dict[str, Any]]:
        """Run a single RMSProp update on one batch and return (loss, state)."""
        batch_loss, grads = self._compute_loss_and_grads(model, Xb, yb)
        self._apply_rmsprop_step(model, state["params"], state["cache"], grads)
        return batch_loss, state

    def _compute_loss_and_grads(self, model: ModelLike, Xb: np.ndarray, yb: np.ndarray) -> tuple[float, Any]:
        """Forward/backward pass for one batch under the configured loss.

        Returns (loss, grads) where grads follows model.get_gradients() structure.
        """
        objective = self._get_loss_fn()
        model.reset_gradients()
        predictions = model.forward(Xb)
        batch_loss = objective.loss(yb, predictions)
        # Seed backprop with dL/dy of the configured objective.
        model.backward(objective.gradient(yb, predictions))
        return batch_loss, model.get_gradients()

    def _apply_rmsprop_step(
        self,
        model: ModelLike,
        params: dict[str, Any],
        cache: dict[str, Any],
        grads: dict[str, Any],
    ) -> None:
        """Update every parameter in-place-style from grads and the EMA caches.

        Handles the consequent weight array and the scalar membership
        parameters, then writes the new parameter set back into the model.
        """
        one_minus_rho = 1.0 - self.rho

        # --- consequent weights (numpy array) ---
        grad_c = grads["consequent"]
        sq_avg = cache["consequent"]
        # Refresh the moving average in place so the cache object is reused.
        sq_avg *= self.rho
        sq_avg += one_minus_rho * (grad_c * grad_c)
        denom = np.sqrt(sq_avg) + self.epsilon
        params["consequent"] = params["consequent"] - self.learning_rate * grad_c / denom

        # --- membership parameters (scalars nested as mf -> term -> name) ---
        for path, theta, sq, g in iterate_membership_params_with_state(params, cache, grads):
            g = cast(float, g)  # grads_dict is provided in this context
            sq = self.rho * sq + one_minus_rho * (g * g)
            update_membership_param(params, path, theta - self.learning_rate * g / (np.sqrt(sq) + self.epsilon))
            # Persist the refreshed scalar cache back into its nested slot.
            cache["membership"][path[0]][path[1]][path[2]] = sq

        # Push updated params back into the model
        model.set_parameters(params)

    def compute_loss(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> float:
        """Evaluate the configured loss on ``(X, y)`` without touching state."""
        return float(self._get_loss_fn().loss(y, model.forward(X)))