Coverage for anfis_toolbox/optim/adam.py: 100% (64 statements)
coverage.py v7.13.3, created at 2026-02-05 18:47 -0300

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, cast

import numpy as np

from ..losses import LossFunction
from ._utils import (
    iterate_membership_params_with_state,
    update_membership_param,
    zeros_like_structure,
)
from .base import BaseTrainer, ModelLike


def _adam_update(
    param: np.ndarray,
    grad: np.ndarray,
    m: np.ndarray,
    v: np.ndarray,
    lr: float,
    beta1: float,
    beta2: float,
    eps: float,
    t: int,
) -> None:
    """Apply one in-place Adam update to numpy arrays (param, grad, m, v)."""
    m[:] = beta1 * m + (1.0 - beta1) * grad
    v[:] = beta2 * v + (1.0 - beta2) * (grad * grad)
    m_hat = m / (1.0 - beta1**t)
    v_hat = v / (1.0 - beta2**t)
    param[:] = param - lr * m_hat / (np.sqrt(v_hat) + eps)
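
# For reference, the update above is the standard Adam rule:
#   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
#   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t**2
#   param <- param - lr * m_hat / (sqrt(v_hat) + eps)
# with bias-corrected m_hat = m_t / (1 - beta1**t) and v_hat = v_t / (1 - beta2**t).
# On the first step (t = 1) m_hat equals g_1 and v_hat equals g_1**2, so the
# parameter moves by roughly lr against the sign of the gradient, independent
# of the gradient's magnitude.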

@dataclass
class AdamTrainer(BaseTrainer):
    """Adam optimizer-based trainer for ANFIS.

    Parameters:
        learning_rate: Base step size (alpha).
        beta1: Exponential decay rate for the first moment estimates.
        beta2: Exponential decay rate for the second moment estimates.
        epsilon: Small constant for numerical stability.
        epochs: Number of passes over the dataset.
        batch_size: If None, use full-batch training; otherwise mini-batches of this size.
        shuffle: Whether to shuffle the data at each epoch when using mini-batches.
        verbose: Unused here; kept for API parity.

    Notes:
        Supports configurable losses via the ``loss`` parameter. Defaults to mean
        squared error for regression, but can minimize other differentiable
        objectives such as categorical cross-entropy when used with ``ANFISClassifier``.
    """

    learning_rate: float = 0.001
    beta1: float = 0.9
    beta2: float = 0.999
    epsilon: float = 1e-8
    epochs: int = 100
    batch_size: None | int = None
    shuffle: bool = True
    verbose: bool = False
    loss: LossFunction | str | None = None
    _loss_fn: LossFunction = field(init=False, repr=False)

    def init_state(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> dict[str, Any]:
        """Initialize Adam's first and second moments and time step.

        Returns a dict with keys: params, m, v, t.
        """
        params = model.get_parameters()
        return {
            "params": params,
            "m": zeros_like_structure(params),
            "v": zeros_like_structure(params),
            "t": 0,
        }

    def train_step(
        self, model: ModelLike, Xb: np.ndarray, yb: np.ndarray, state: dict[str, Any]
    ) -> tuple[float, dict[str, Any]]:
        """One Adam step on a batch; returns (loss, updated_state)."""
        loss, grads = self._compute_loss_and_grads(model, Xb, yb)
        t_val = cast(int, state["t"])
        t_new = self._apply_adam_step(model, state["params"], grads, state["m"], state["v"], t_val)
        state["t"] = t_new
        return loss, state

    def _compute_loss_and_grads(self, model: ModelLike, Xb: np.ndarray, yb: np.ndarray) -> tuple[float, Any]:
        """Forward pass, configured loss, backward pass, and gradients for a batch.

        Returns (loss, grads) where grads follows model.get_gradients() structure.
        """
        loss_fn = self._get_loss_fn()
        model.reset_gradients()
        y_pred = model.forward(Xb)
        loss = loss_fn.loss(yb, y_pred)
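        # ``loss_fn.gradient`` yields dL/dy_pred; ``model.backward`` propagates it
        # through the network so that ``model.get_gradients()`` returns dL/dtheta
        # for the consequent and membership parameters used in ``_apply_adam_step``.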
        dL_dy = loss_fn.gradient(yb, y_pred)
        model.backward(dL_dy)
        grads = model.get_gradients()
        return loss, grads

    def _apply_adam_step(
        self,
        model: Any,
        params: dict[str, Any],
        grads: dict[str, Any],
        m: dict[str, Any],
        v: dict[str, Any],
        t: int,
    ) -> int:
        """Apply one Adam update to params using grads and moments; returns new time step.

        Updates both consequent array parameters and membership scalar parameters.
        """
        t += 1
        _adam_update(
            params["consequent"],
            grads["consequent"],
            m["consequent"],
            v["consequent"],
            self.learning_rate,
            self.beta1,
            self.beta2,
            self.epsilon,
            t,
        )
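        # The loop below is the scalar analogue of ``_adam_update``: each
        # membership parameter is addressed by a three-level path into the
        # nested dicts, and its entries in ``m`` and ``v`` are advanced with
        # the same bias-corrected rule before the parameter is written back.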
        # Membership parameters (scalars in nested dicts)
        for path, param_val, m_val, grad in iterate_membership_params_with_state(params, m, grads):
            grad = cast(float, grad)  # grads_dict is provided in this context
            m_val = self.beta1 * m_val + (1.0 - self.beta1) * grad
            v_val = v["membership"][path[0]][path[1]][path[2]]
            v_val = self.beta2 * v_val + (1.0 - self.beta2) * (grad * grad)
            m_hat = m_val / (1.0 - self.beta1**t)
            v_hat = v_val / (1.0 - self.beta2**t)
            step = self.learning_rate * m_hat / (np.sqrt(v_hat) + self.epsilon)
            new_param = param_val - step
            update_membership_param(params, path, new_param)
            m["membership"][path[0]][path[1]][path[2]] = m_val
            v["membership"][path[0]][path[1]][path[2]] = v_val

        # Push updated params back into the model
        model.set_parameters(params)
        return t

    def compute_loss(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> float:
        """Evaluate the configured loss on ``(X, y)`` without updating parameters."""
        loss_fn = self._get_loss_fn()
        preds = model.forward(X)
        return float(loss_fn.loss(y, preds))
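

# Usage sketch (illustrative only, not part of the module): ``model``, ``X`` and
# ``y`` are hypothetical placeholders for a ModelLike ANFIS instance and its
# training data; the epoch/batch loop is presumably driven by ``BaseTrainer``
# and is not shown, so only the methods defined above appear.
#
#     trainer = AdamTrainer(learning_rate=0.01, epochs=50)
#     state = trainer.init_state(model, X, y)
#     for _ in range(trainer.epochs):
#         loss, state = trainer.train_step(model, X, y, state)
#     final_loss = trainer.compute_loss(model, X, y)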