Coverage for anfis_toolbox / optim / hybrid_adam.py: 100%
82 statements
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-05 18:47 -0300
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-05 18:47 -0300
1from __future__ import annotations
3import logging
4from copy import deepcopy
5from dataclasses import dataclass
6from typing import Any
8import numpy as np
10from ..losses import MSELoss
11from ..model import TSKANFIS
12from ._utils import iterate_membership_params_with_state, update_membership_param, zeros_like_structure
13from .base import BaseTrainer, ModelLike
@dataclass
class HybridAdamTrainer(BaseTrainer):
    """Trainer that alternates a least-squares solve with an Adam step.

    Each training step first fits the consequent parameters with a
    regularized least-squares solve (LSM) and then moves the membership
    (antecedent) parameters with one Adam update driven by the MSE loss.

    Notes:
        Only the regression model (:class:`~anfis_toolbox.model.TSKANFIS`) is
        accepted; any other model type is rejected with ``TypeError`` by
        :meth:`_require_regression_model`.

    Attributes:
        learning_rate: Adam step size.
        beta1: Decay rate of the first-moment estimate.
        beta2: Decay rate of the second-moment estimate.
        epsilon: Constant added to the denominator of the Adam step.
        epochs: Number of epochs (not read inside this class itself).
        verbose: Verbosity flag (not read inside this class itself).
    """

    learning_rate: float = 0.001
    beta1: float = 0.9
    beta2: float = 0.999
    epsilon: float = 1e-8
    epochs: int = 100
    verbose: bool = False
    _loss_fn: MSELoss = MSELoss()

    def init_state(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> dict[str, Any]:
        """Create the Adam state: zeroed moments ``m``/``v`` and step counter ``t``.

        Args:
            model: Model whose membership parameters define the state layout.
            X: Training inputs (unused here).
            y: Training targets (unused here).

        Returns:
            Dict with keys ``"m"``, ``"v"`` (independent zero-filled copies of
            the membership parameter structure) and ``"t"`` (set to ``0``).

        Raises:
            TypeError: If ``model`` is not a ``TSKANFIS`` instance.
        """
        regression_model = self._require_regression_model(model)
        membership_zeros = zeros_like_structure(regression_model.get_parameters())["membership"]
        # Two separate deep copies so first and second moments never alias.
        first_moment = deepcopy(membership_zeros)
        second_moment = deepcopy(membership_zeros)
        return {"m": first_moment, "v": second_moment, "t": 0}

    def train_step(
        self, model: ModelLike, Xb: np.ndarray, yb: np.ndarray, state: dict[str, Any]
    ) -> tuple[float, dict[str, Any]]:
        """Run one hybrid step: LSM on consequents, then Adam on memberships.

        Args:
            model: Regression model to update in place.
            Xb: Batch inputs.
            yb: Batch targets.
            state: Adam state as produced by :meth:`init_state`; mutated in place.

        Returns:
            Tuple of the batch loss (as ``float``) and the same ``state`` object.

        Raises:
            TypeError: If ``model`` is not a ``TSKANFIS`` instance.
        """
        model = self._require_regression_model(model)
        model.reset_gradients()
        Xb, yb = self._prepare_training_data(model, Xb, yb)
        normalized_weights = model.forward_antecedents(Xb)

        # --- Least-squares solve for the consequent parameters ---
        n_samples = Xb.shape[0]
        bias_column = np.ones((n_samples, 1), dtype=float)
        augmented_inputs = np.concatenate([Xb, bias_column], axis=1)
        # One block of columns per rule: the rule's weight times [inputs, 1].
        design = np.hstack(
            [normalized_weights[:, rule : rule + 1] * augmented_inputs for rule in range(model.n_rules)]
        )
        targets = yb.flatten()
        # Small ridge term on the normal equations.
        gram = design.T @ design + 1e-6 * np.eye(design.shape[1])
        try:
            theta = np.linalg.solve(gram, design.T @ targets)
        except np.linalg.LinAlgError:
            logging.getLogger(__name__).warning("Matrix singular in LSM, using pseudo-inverse")
            theta = np.linalg.pinv(design) @ targets
        model.consequent_layer.parameters = theta.reshape((model.n_rules, model.n_inputs + 1))

        # --- Adam step for the membership (antecedent) parameters ---
        predictions = model.consequent_layer.forward(Xb, normalized_weights)
        loss_value = self._loss_fn.loss(yb, predictions)
        grad_output = self._loss_fn.gradient(yb, predictions)
        # Backpropagate layer by layer; the consequent-parameter gradient is discarded.
        grad_norm_weights, _ = model.consequent_layer.backward(grad_output)
        grad_weights = model.normalization_layer.backward(grad_norm_weights)
        grad_memberships = model.rule_layer.backward(grad_weights)
        membership_grads = model.membership_layer.backward(grad_memberships)
        self._apply_adam_update(model, membership_grads, state)
        return float(loss_value), state

    def _apply_adam_update(self, model: ModelLike, grad_struct: dict[str, Any], state: dict[str, Any]) -> None:
        """Apply one bias-corrected Adam update to every membership parameter.

        Increments ``state["t"]``, updates the moments stored in ``state["m"]``
        and ``state["v"]`` in place, and writes the new parameters back to the
        model. Entries whose gradient is ``None`` are skipped.

        Raises:
            TypeError: If ``model`` is not a ``TSKANFIS`` instance.
        """
        model = self._require_regression_model(model)
        params = model.get_parameters()
        first_moments = state["m"]
        second_moments = state["v"]
        state["t"] += 1
        t = state["t"]

        # Bias-correction denominators depend only on the step count.
        m_correction = 1.0 - self.beta1**t
        v_correction = 1.0 - self.beta2**t

        entries = iterate_membership_params_with_state(params, {"membership": first_moments}, grad_struct)
        for path, param_val, m_prev, raw_grad in entries:
            if raw_grad is None:
                continue
            g = float(raw_grad)
            m_new = self.beta1 * m_prev + (1.0 - self.beta1) * g
            v_prev = second_moments[path[0]][path[1]][path[2]]
            v_new = self.beta2 * v_prev + (1.0 - self.beta2) * (g * g)
            m_hat = m_new / m_correction
            v_hat = v_new / v_correction
            step = self.learning_rate * m_hat / (np.sqrt(v_hat) + self.epsilon)
            update_membership_param(params, path, param_val - step)
            # Persist the raw (uncorrected) moments for the next step.
            first_moments[path[0]][path[1]][path[2]] = m_new
            second_moments[path[0]][path[1]][path[2]] = v_new

        model.set_parameters(params)

    def compute_loss(self, model: ModelLike, X: np.ndarray, y: np.ndarray) -> float:
        """Return the loss of ``model`` on ``(X, y)`` without updating parameters.

        Raises:
            TypeError: If ``model`` is not a ``TSKANFIS`` instance.
        """
        model = self._require_regression_model(model)
        inputs, targets = self._prepare_validation_data(model, X, y)
        rule_weights = model.forward_antecedents(inputs)
        predictions = model.consequent_layer.forward(inputs, rule_weights)
        return float(self._loss_fn.loss(targets, predictions))

    @staticmethod
    def _require_regression_model(model: ModelLike) -> TSKANFIS:
        """Return ``model`` unchanged if it is a ``TSKANFIS``, else raise ``TypeError``."""
        if isinstance(model, TSKANFIS):
            return model
        raise TypeError("HybridAdamTrainer supports TSKANFIS regression models only")