Coverage for anfis_toolbox / optim / rmsprop.py: 100%
51 statements
« prev ^ index » next — coverage.py v7.13.3, created at 2026-02-05 18:47 -0300
1from __future__ import annotations
3from dataclasses import dataclass, field
4from typing import Any, cast
6import numpy as np
8from ..losses import LossFunction
9from ._utils import (
10 iterate_membership_params_with_state,
11 update_membership_param,
12 zeros_like_structure,
13)
14from .base import BaseTrainer, ModelLike
@dataclass
class RMSPropTrainer(BaseTrainer):
    """RMSProp optimizer-based trainer for ANFIS.

    Maintains an exponential moving average of squared gradients per
    parameter and scales each step by the inverse root of that average.

    Parameters:
        learning_rate: Base step size (alpha).
        rho: Exponential decay rate for the squared gradient moving average.
        epsilon: Small constant for numerical stability.
        epochs: Number of passes over the dataset.
        batch_size: If None, use full-batch; otherwise mini-batches of this size.
        shuffle: Whether to shuffle the data at each epoch when using mini-batches.
        verbose: Unused here; kept for API parity.

    Notes:
        Supports configurable losses via the ``loss`` parameter. Defaults to mean squared error for
        regression tasks but can be switched to other differentiable objectives such as categorical
        cross-entropy when training ``ANFISClassifier`` models.
    """

    learning_rate: float = 0.001
    rho: float = 0.9
    epsilon: float = 1e-8
    epochs: int = 100
    batch_size: None | int = None
    shuffle: bool = True
    verbose: bool = False
    loss: LossFunction | str | None = None
    # Resolved loss callable; populated lazily via _get_loss_fn(), so it is
    # excluded from __init__ and repr.
    _loss_fn: LossFunction = field(init=False, repr=False)

    def init_state(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> dict[str, Any]:
        """Initialize RMSProp caches for consequents and membership scalars.

        Returns a state dict with the model's current parameters and a
        zero-initialized cache mirroring their nested structure.
        """
        params = model.get_parameters()
        return {"params": params, "cache": zeros_like_structure(params)}

    def train_step(
        self, model: ModelLike, Xb: np.ndarray, yb: np.ndarray, state: dict[str, Any]
    ) -> tuple[float, dict[str, Any]]:
        """One RMSProp step on a batch; returns (loss, updated_state)."""
        loss, grads = self._compute_loss_and_grads(model, Xb, yb)
        self._apply_rmsprop_step(model, state["params"], state["cache"], grads)
        return loss, state

    def _compute_loss_and_grads(self, model: ModelLike, Xb: np.ndarray, yb: np.ndarray) -> tuple[float, Any]:
        """Forward pass, configured loss, backward pass, and gradients for a batch.

        Uses the loss selected via ``self.loss`` (not necessarily MSE).
        Returns (loss, grads) where grads follows model.get_gradients() structure.
        """
        loss_fn = self._get_loss_fn()
        model.reset_gradients()
        y_pred = model.forward(Xb)
        # Coerce to a builtin float so train_step's annotated return type holds
        # even when loss_fn.loss yields a NumPy scalar; this mirrors
        # compute_loss(), which performs the same coercion.
        loss = float(loss_fn.loss(yb, y_pred))
        dL_dy = loss_fn.gradient(yb, y_pred)
        model.backward(dL_dy)
        grads = model.get_gradients()
        return loss, grads

    def _apply_rmsprop_step(
        self,
        model: ModelLike,
        params: dict[str, Any],
        cache: dict[str, Any],
        grads: dict[str, Any],
    ) -> None:
        """Apply one RMSProp update to params using grads and caches.

        Updates both consequent array parameters and membership scalar parameters,
        then pushes the updated params back into the model.
        """
        # Consequent is a numpy array: update its cache in place and take a
        # vectorized RMSProp step.
        g = grads["consequent"]
        c = cache["consequent"]
        c[:] = self.rho * c + (1.0 - self.rho) * (g * g)
        params["consequent"] = params["consequent"] - self.learning_rate * g / (np.sqrt(c) + self.epsilon)

        # Membership parameters (scalars in nested dicts): same rule applied
        # element-by-element, writing the refreshed cache back at each path.
        for path, param_val, cache_val, grad in iterate_membership_params_with_state(params, cache, grads):
            grad = cast(float, grad)  # grads_dict is provided in this context
            cache_val = self.rho * cache_val + (1.0 - self.rho) * (grad * grad)
            step = self.learning_rate * grad / (np.sqrt(cache_val) + self.epsilon)
            new_param = param_val - step
            update_membership_param(params, path, new_param)
            cache["membership"][path[0]][path[1]][path[2]] = cache_val

        # Push updated params back into the model
        model.set_parameters(params)

    def compute_loss(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> float:
        """Return the current loss value for ``(X, y)`` without modifying state."""
        loss_fn = self._get_loss_fn()
        preds = model.forward(X)
        return float(loss_fn.loss(y, preds))