Coverage for anfis_toolbox / losses.py: 100%
90 statements
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-05 18:47 -0300
1"""Loss functions and their gradients for ANFIS Toolbox.
3This module centralizes the loss definitions used during training to make it
4explicit which objective is being optimized. Trainers can import from here so
5the chosen loss is clear in one place.
6"""
8from __future__ import annotations
10from typing import Any, cast
12import numpy as np
class LossFunction:
    """Contract implemented by every loss usable with the trainers.

    Subclasses override ``loss`` and ``gradient`` (and optionally
    ``prepare_targets``) to define a concrete training objective.

    Typical workflow:
        1. ``prepare_targets`` shapes raw targets into the expected layout.
        2. ``loss`` evaluates the scalar objective.
        3. ``gradient`` yields the loss gradients for backpropagation.
    """

    def prepare_targets(self, y: Any, *, model: Any | None = None) -> np.ndarray:
        """Coerce raw targets into an array usable by forward/gradient code."""
        return np.asarray(y, dtype=float)

    def loss(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:  # pragma: no cover - interface
        """Evaluate the scalar loss for the given targets and predictions."""
        raise NotImplementedError

    def gradient(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:  # pragma: no cover - interface
        """Return the gradient of the loss with respect to the predictions."""
        raise NotImplementedError
class MSELoss(LossFunction):
    """Mean squared error loss packaged for trainer consumption.

    Implements the MSE loss commonly used for regression tasks: the average
    squared difference between predicted and actual values.

    The loss is defined as:
        L = (1/n) * Σ(y_pred - y_true)²

    And its gradient with respect to the predictions is:
        ∇L = (2/n) * (y_pred - y_true)
    """

    def prepare_targets(self, y: Any, *, model: Any | None = None) -> np.ndarray:
        """Convert 1D targets into column vectors expected by MSE computations.

        Parameters:
            y: Array-like target values. Can be 1D or already 2D.
            model: Optional model instance (unused for MSE).

        Returns:
            np.ndarray: Targets as a 2D column vector of shape (n_samples, 1).
        """
        y_arr = np.asarray(y, dtype=float)
        if y_arr.ndim == 1:
            y_arr = y_arr.reshape(-1, 1)
        return y_arr

    def loss(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Compute the mean squared error (MSE).

        Parameters:
            y_true: Array-like of true target values, shape (...,)
            y_pred: Array-like of predicted values, same shape as y_true

        Returns:
            The mean of squared differences over all elements as a float.

        Notes:
            - Inputs are coerced to NumPy arrays with dtype=float.
            - Broadcasting follows NumPy semantics. If shapes are not compatible
              for element-wise subtraction, a ValueError will be raised by NumPy.
            - An empty batch yields 0.0 rather than NaN, matching the
              empty-batch behavior of CrossEntropyLoss.loss and avoiding
              NumPy's mean-of-empty-slice warning.
        """
        yt = np.asarray(y_true, dtype=float)
        yp = np.asarray(y_pred, dtype=float)
        diff = yt - yp
        if diff.size == 0:
            # Consistent with CrossEntropyLoss: an empty batch carries no loss.
            return 0.0
        return float(np.mean(diff * diff))

    def gradient(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
        """Compute gradient of MSE with respect to predictions.

        The gradient is computed as: ∇L = (2/n) * (y_pred - y_true)

        Parameters:
            y_true: True target values, shape (n_samples, 1).
            y_pred: Predicted values, same shape as y_true.

        Returns:
            np.ndarray: Gradient array with same shape as y_pred.
        """
        yt = np.asarray(y_true, dtype=float)
        yp = np.asarray(y_pred, dtype=float)
        # max(1, ...) guards the division against an empty batch.
        n = max(1, yt.shape[0])
        return cast(np.ndarray, 2.0 * (yp - yt) / float(n))
class CrossEntropyLoss(LossFunction):
    """Categorical cross-entropy loss operating on raw logits.

    Intended for multi-class classification: predictions are unbounded logit
    scores, and both the loss and its gradient are evaluated through
    numerically stable softmax/log-softmax formulations so arbitrarily large
    or small logits do not overflow.

    Definitions:
        L  = -(1/n) * Σ_i Σ_j y_true[i, j] * log(softmax(logits)[i, j])
        ∇L = (1/n) * (softmax(logits) - y_true)
    """

    def _stable_softmax(self, x: np.ndarray, axis: int) -> np.ndarray:
        """Return softmax probabilities computed without overflow.

        Subtracts the per-row maximum before exponentiating; this leaves the
        result mathematically unchanged while keeping ``np.exp`` in a safe
        range for large logits.

        Parameters:
            x: Input logits array, shape (..., n_classes).
            axis: Axis along which to normalize (typically 1 for a batch).

        Returns:
            np.ndarray: Probabilities with the same shape as ``x``, in [0, 1]
            and summing to 1 along ``axis``.
        """
        shifted = x - np.max(x, axis=axis, keepdims=True)
        exps = np.exp(shifted)
        return cast(np.ndarray, exps / np.sum(exps, axis=axis, keepdims=True))

    def prepare_targets(self, y: Any, *, model: Any | None = None) -> np.ndarray:
        """Normalize labels or one-hot encodings into a dense float matrix.

        Accepts either:
        - 1D integer class labels (0 to n_classes-1), or
        - a 2D one-hot encoded target matrix.

        Integer labels are expanded to one-hot form. When ``model`` exposes an
        ``n_classes`` attribute it fixes the expected number of columns and is
        used to validate 2D input.

        Parameters:
            y: Target labels as a 1D integer array or 2D one-hot array.
            model: Optional model instance whose ``n_classes`` attribute, if
                present, determines/validates the class count.

        Returns:
            np.ndarray: One-hot encoded targets of shape (n_samples, n_classes).

        Raises:
            ValueError: If ``y`` is neither 1D nor 2D, or the one-hot width
                disagrees with the model's class count.
        """
        targets = np.asarray(y)
        if targets.ndim == 1:
            declared = getattr(model, "n_classes", None) if model is not None else None
            # Infer the class count from the labels when the model does not declare it.
            n_classes = int(declared) if declared is not None else int(np.max(targets)) + 1
            one_hot = np.zeros((targets.shape[0], n_classes), dtype=float)
            one_hot[np.arange(targets.shape[0]), targets.astype(int)] = 1.0
            return one_hot
        if targets.ndim != 2:
            raise ValueError("y for cross-entropy must be 1D labels or 2D one-hot encoded")
        declared = getattr(model, "n_classes", None) if model is not None else None
        if declared is not None:
            expected = int(declared)
            if targets.shape[1] != expected:
                raise ValueError(f"y one-hot must have {expected} columns")
        return targets.astype(float)

    def loss(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Compute mean cross-entropy from integer labels or one-hot vs logits.

        Uses a stable log-softmax (log-sum-exp trick) to avoid numerical
        underflow/overflow, and accepts either integer class labels or
        one-hot encoded targets.

        Parameters:
            y_true: Shape (n_samples,) of integer class labels, or one-hot
                encoded array of shape (n_samples, n_classes).
            y_pred: Raw logit scores of shape (n_samples, n_classes).

        Returns:
            float: Mean cross-entropy loss across all samples.

        Notes:
            - Returns 0.0 for an empty batch (n_samples == 0).
            - Numerically stable for arbitrarily large or small logit values.
        """
        logits = np.asarray(y_pred, dtype=float)
        batch = logits.shape[0]
        if batch == 0:
            return 0.0
        # Stable log-softmax: shift by the per-row max before exponentiating.
        row_max = np.max(logits, axis=1, keepdims=True)
        log_sum_exp = row_max + np.log(np.sum(np.exp(logits - row_max), axis=1, keepdims=True))
        log_probs = logits - log_sum_exp  # shape (batch, n_classes)

        targets = np.asarray(y_true)
        if targets.ndim == 1:
            # Integer labels: pick the log-probability of each true class.
            labels = targets.reshape(-1)
            if labels.shape[0] != batch:
                raise ValueError("y_true length must match logits batch size")
            per_sample_nll = -log_probs[np.arange(batch), labels.astype(int)]
        else:
            # One-hot targets: mask the log-probabilities row by row.
            if targets.shape != logits.shape:
                raise ValueError("For one-hot y_true, shape must match logits")
            per_sample_nll = -np.sum(targets * log_probs, axis=1)
        return float(np.mean(per_sample_nll))

    def gradient(self, y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
        """Compute gradient of cross-entropy with respect to the logits.

        Uses the classical simplification obtained from the chain rule:
        ∇L = (softmax(logits) - one_hot(y_true)) / n.

        Parameters:
            y_true: Shape (n_samples,) of integer class labels, or one-hot
                encoded array of shape (n_samples, n_classes).
            y_pred: Raw logit scores of shape (n_samples, n_classes).

        Returns:
            np.ndarray: Gradient of shape (n_samples, n_classes).

        Raises:
            ValueError: If one-hot ``y_true`` shape doesn't match the logits.
        """
        logits = np.asarray(y_pred, dtype=float)
        batch, n_classes = logits.shape[0], logits.shape[1]
        targets = np.asarray(y_true)
        if targets.ndim == 1:
            # Expand integer labels into a dense one-hot matrix.
            one_hot = np.zeros((batch, n_classes), dtype=float)
            one_hot[np.arange(batch), targets.astype(int)] = 1.0
            targets = one_hot
        elif targets.shape != logits.shape:
            raise ValueError("y_true one-hot must have same shape as logits")
        else:
            targets = targets.astype(float)
        probs = self._stable_softmax(logits, axis=1)
        return cast(np.ndarray, (probs - targets) / float(batch))
# Maps user-facing loss names (lowercase) to their implementing classes.
# Several aliases point at the same class so `resolve_loss` accepts common
# spellings interchangeably.
LOSS_REGISTRY: dict[str, type[LossFunction]] = {
    "mse": MSELoss,
    "mean_squared_error": MSELoss,
    "cross_entropy": CrossEntropyLoss,
    "crossentropy": CrossEntropyLoss,
    "cross-entropy": CrossEntropyLoss,
}
def resolve_loss(loss: str | LossFunction | None) -> LossFunction:
    """Turn a user-provided loss spec into a concrete ``LossFunction`` instance.

    Allows flexible specification: a registry name, a ready-made instance,
    or ``None`` for the default.

    Parameters:
        loss: One of:
            - None: falls back to the default ``MSELoss()``.
            - str: case-insensitive key into ``LOSS_REGISTRY``.
            - LossFunction: returned unchanged.

    Returns:
        LossFunction: Instantiated loss function ready for use.

    Raises:
        ValueError: If a string name is not present in ``LOSS_REGISTRY``.
        TypeError: If ``loss`` is not None, str, or a ``LossFunction``.

    Examples:
        >>> default = resolve_loss(None)  # MSELoss()
        >>> by_name = resolve_loss("mse")
        >>> ce = resolve_loss("cross_entropy")
        >>> passthrough = resolve_loss(CrossEntropyLoss())
    """
    if loss is None:
        return MSELoss()
    if isinstance(loss, LossFunction):
        return loss
    if isinstance(loss, str):
        registered = LOSS_REGISTRY.get(loss.lower())
        if registered is None:
            raise ValueError(f"Unknown loss '{loss}'. Available: {sorted(LOSS_REGISTRY)}")
        return registered()
    raise TypeError("loss must be None, str, or a LossFunction instance")
# Explicit public API of this module.
__all__ = [
    "LossFunction",
    "MSELoss",
    "CrossEntropyLoss",
    "resolve_loss",
]