Coverage for anfis_toolbox / optim / adam.py: 100%

64 statements  

coverage.py v7.13.3, created at 2026-02-05 18:47 -0300

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, cast

import numpy as np

from ..losses import LossFunction
from ._utils import (
    iterate_membership_params_with_state,
    update_membership_param,
    zeros_like_structure,
)
from .base import BaseTrainer, ModelLike


def _adam_update(
    param: np.ndarray,
    grad: np.ndarray,
    m: np.ndarray,
    v: np.ndarray,
    lr: float,
    beta1: float,
    beta2: float,
    eps: float,
    t: int,
) -> None:
    """Compute Adam update for numpy arrays (param, grad, m, v)."""
    m[:] = beta1 * m + (1.0 - beta1) * grad
    v[:] = beta2 * v + (1.0 - beta2) * (grad * grad)
    m_hat = m / (1.0 - beta1**t)
    v_hat = v / (1.0 - beta2**t)
    param[:] = param - lr * m_hat / (np.sqrt(v_hat) + eps)
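# The helper above implements the standard bias-corrected Adam rule
# (Kingma & Ba, 2015), applied in place to the arrays it receives:
#   m_t   = beta1 * m_{t-1} + (1 - beta1) * g_t
#   v_t   = beta2 * v_{t-1} + (1 - beta2) * g_t**2
#   m_hat = m_t / (1 - beta1**t),  v_hat = v_t / (1 - beta2**t)
#   param <- param - lr * m_hat / (sqrt(v_hat) + eps)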


@dataclass
class AdamTrainer(BaseTrainer):
    """Adam optimizer-based trainer for ANFIS.

    Parameters:
        learning_rate: Base step size (alpha).
        beta1: Exponential decay rate for the first moment estimates.
        beta2: Exponential decay rate for the second moment estimates.
        epsilon: Small constant for numerical stability.
        epochs: Number of passes over the dataset.
        batch_size: If None, use full-batch; otherwise mini-batches of this size.
        shuffle: Whether to shuffle the data at each epoch when using mini-batches.
        verbose: Unused here; kept for API parity.

    Notes:
        Supports configurable losses via the ``loss`` parameter. Defaults to mean
        squared error for regression, but can minimize other differentiable
        objectives such as categorical cross-entropy when used with
        ``ANFISClassifier``.
    """

    learning_rate: float = 0.001
    beta1: float = 0.9
    beta2: float = 0.999
    epsilon: float = 1e-8
    epochs: int = 100
    batch_size: None | int = None
    shuffle: bool = True
    verbose: bool = False
    loss: LossFunction | str | None = None
    _loss_fn: LossFunction = field(init=False, repr=False)

    def init_state(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> dict[str, Any]:
        """Initialize Adam's first and second moments and time step.

        Returns a dict with keys: params, m, v, t.
        """
        params = model.get_parameters()
        return {
            "params": params,
            "m": zeros_like_structure(params),
            "v": zeros_like_structure(params),
            "t": 0,
        }
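    # Illustrative state shape (a sketch, inferred from _apply_adam_step below;
    # the exact nesting mirrors whatever model.get_parameters() returns):
    #   {"params": {"consequent": ndarray, "membership": {...nested scalars...}},
    #    "m": zeros like params, "v": zeros like params, "t": 0}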

    def train_step(
        self, model: ModelLike, Xb: np.ndarray, yb: np.ndarray, state: dict[str, Any]
    ) -> tuple[float, dict[str, Any]]:
        """One Adam step on a batch; returns (loss, updated_state)."""
        loss, grads = self._compute_loss_and_grads(model, Xb, yb)
        t_val = cast(int, state["t"])
        t_new = self._apply_adam_step(model, state["params"], grads, state["m"], state["v"], t_val)
        state["t"] = t_new
        return loss, state

    def _compute_loss_and_grads(self, model: ModelLike, Xb: np.ndarray, yb: np.ndarray) -> tuple[float, Any]:
        """Forward pass, configured loss, backward pass, and gradients for a batch.

        Returns (loss, grads) where grads follows model.get_gradients() structure.
        """
        loss_fn = self._get_loss_fn()
        model.reset_gradients()
        y_pred = model.forward(Xb)
        loss = loss_fn.loss(yb, y_pred)
        dL_dy = loss_fn.gradient(yb, y_pred)
        model.backward(dL_dy)
        grads = model.get_gradients()
        return loss, grads

    def _apply_adam_step(
        self,
        model: Any,
        params: dict[str, Any],
        grads: dict[str, Any],
        m: dict[str, Any],
        v: dict[str, Any],
        t: int,
    ) -> int:
        """Apply one Adam update to params using grads and moments; returns the new time step.

        Updates both the consequent array parameters and the scalar membership parameters.
        """
        t += 1
        _adam_update(
            params["consequent"],
            grads["consequent"],
            m["consequent"],
            v["consequent"],
            self.learning_rate,
            self.beta1,
            self.beta2,
            self.epsilon,
            t,
        )
        # Membership parameters (scalars in nested dicts)
        for path, param_val, m_val, grad in iterate_membership_params_with_state(params, m, grads):
            grad = cast(float, grad)  # the grads dict is always provided in this context
            m_val = self.beta1 * m_val + (1.0 - self.beta1) * grad
            v_val = v["membership"][path[0]][path[1]][path[2]]
            v_val = self.beta2 * v_val + (1.0 - self.beta2) * (grad * grad)
            m_hat = m_val / (1.0 - self.beta1**t)
            v_hat = v_val / (1.0 - self.beta2**t)
            step = self.learning_rate * m_hat / (np.sqrt(v_hat) + self.epsilon)
            new_param = param_val - step
            update_membership_param(params, path, new_param)
            m["membership"][path[0]][path[1]][path[2]] = m_val
            v["membership"][path[0]][path[1]][path[2]] = v_val

        # Push the updated params back into the model
        model.set_parameters(params)
        return t
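    # Note: the scalar loop above mirrors _adam_update element by element;
    # membership parameters live as Python floats in nested dicts, so the
    # vectorized numpy helper is not applied to them directly.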

    def compute_loss(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> float:
        """Evaluate the configured loss on ``(X, y)`` without updating parameters."""
        loss_fn = self._get_loss_fn()
        preds = model.forward(X)
        return float(loss_fn.loss(y, preds))
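
For orientation, the loop below drives the trainer with only the methods defined above. This is a sketch, not part of the module: model, X, and y are hypothetical placeholders for an object satisfying the ModelLike protocol (forward, backward, reset_gradients, get_parameters, set_parameters, get_gradients) and its training arrays; any mini-batch slicing and shuffling that BaseTrainer performs around train_step is omitted here.

    from anfis_toolbox.optim.adam import AdamTrainer

    trainer = AdamTrainer(learning_rate=0.01, epochs=50)

    # model, X, y are placeholders, not defined in this module.
    state = trainer.init_state(model, X, y)
    for _ in range(trainer.epochs):
        # Full-batch Adam steps; loss is the training loss on this pass.
        loss, state = trainer.train_step(model, X, y, state)

    print("final loss:", trainer.compute_loss(model, X, y))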