Coverage for anfis_toolbox / estimator_utils.py: 100%
161 statements
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-05 18:47 -0300
1"""Lightweight utilities for scikit-learn style estimators without external dependencies.
3This module provides a minimal subset of the scikit-learn estimator contract so that
4high-level ANFIS interfaces can expose familiar methods (`fit`, `predict`,
5`get_params`, `set_params`, etc.) without requiring scikit-learn as a runtime
6dependency. The helpers here intentionally implement only the pieces we need
7and keep them Numpy-centric for portability.
8"""
10from __future__ import annotations
12from collections.abc import Iterable
13from copy import deepcopy
14from dataclasses import dataclass
15from typing import Any
17import numpy as np
# Optional scikit-learn integration: when sklearn is importable, real ``Tags``
# objects can be constructed; otherwise the module falls back to plain dicts.
try:  # pragma: no cover - optional dependency
    from sklearn.utils._tags import Tags, TargetTags

    _SKLEARN_TAGS_AVAILABLE = True
except Exception:  # pragma: no cover - sklearn not installed
    # Placeholders keep the names defined so later code can reference them.
    Tags = None
    TargetTags = None
    _SKLEARN_TAGS_AVAILABLE = False
# Explicit public API of this module (star-imports and doc tooling).
__all__ = [
    "BaseEstimatorLike",
    "RegressorMixinLike",
    "ClassifierMixinLike",
    "FittedMixin",
    "RuleInspectorMixin",
    "NotFittedError",
    "check_is_fitted",
    "ensure_2d_array",
    "ensure_vector",
    "infer_feature_names",
    "format_estimator_repr",
]
class NotFittedError(RuntimeError):
    """Raised when fitted-only functionality is used on an unfitted estimator."""
47class BaseEstimatorLike:
48 """Mixin implementing scikit-learn style parameter inspection.
50 Parameters are assumed to live on the instance `__dict__` and be declared in
51 `__init__`. This matches the common sklearn design pattern and enables
52 cloning/grid-search like workflows without relying on sklearn itself.
53 """
55 def get_params(self, deep: bool = True) -> dict[str, Any]:
56 """Return estimator parameters following sklearn conventions."""
58 def clone_param(value: Any) -> Any:
59 if isinstance(value, dict):
60 return {k: clone_param(v) for k, v in value.items()}
61 if isinstance(value, (list, tuple)):
62 return type(value)(clone_param(v) for v in value)
63 # Primitive / numpy scalars
64 if isinstance(value, (str, int, float, bool, type(None), np.generic)):
65 return value
66 # Fallback to deepcopy for custom objects
67 return deepcopy(value)
69 return {key: clone_param(value) for key, value in self.__dict__.items() if not key.endswith("_")}
71 def set_params(self, **params: Any) -> BaseEstimatorLike:
72 """Set estimator parameters and return self."""
73 for key, value in params.items():
74 if not hasattr(self, key):
75 raise ValueError(f"Invalid parameter '{key}' for {type(self).__name__}.")
76 setattr(self, key, value)
77 return self
79 # ------------------------------------------------------------------
80 # scikit-learn compatibility hooks
81 # ------------------------------------------------------------------
82 def __sklearn_tags__(self) -> dict[str, Any] | Any:
83 """Return estimator capability tags expected by scikit-learn."""
84 merged: dict[str, Any] = {}
85 more_tags = getattr(self, "_more_tags", None)
86 if callable(more_tags):
87 merged.update(more_tags())
88 extra_tags = getattr(self, "_sklearn_tags", None)
89 if isinstance(extra_tags, dict):
90 merged.update(extra_tags)
92 if _SKLEARN_TAGS_AVAILABLE:
93 # Construct a Tags object with proper defaults
94 estimator_type = merged.pop("estimator_type", None)
95 non_deterministic = merged.pop("non_deterministic", False)
96 requires_fit = merged.pop("requires_fit", True)
97 requires_y = merged.pop("requires_y", True)
99 # Create target_tags
100 target_tags = TargetTags(required=requires_y)
102 # Create Tags object
103 tags = Tags(
104 estimator_type=estimator_type,
105 target_tags=target_tags,
106 transformer_tags=None,
107 classifier_tags=None,
108 regressor_tags=None,
109 non_deterministic=non_deterministic,
110 requires_fit=requires_fit,
111 )
113 # Remaining keys are not recognised by the public Tags API; ignore gracefully.
114 return tags
116 # Fallback lightweight representation when scikit-learn is not available.
117 fallback = {
118 "estimator_type": merged.pop("estimator_type", None),
119 "non_deterministic": merged.pop("non_deterministic", False),
120 "requires_y": merged.pop("requires_y", True),
121 "requires_fit": merged.pop("requires_fit", True),
122 }
123 fallback.update(merged)
124 return fallback
126 def __sklearn_is_fitted__(self) -> bool:
127 """Expose estimator fitted state to scikit-learn utilities."""
128 return bool(getattr(self, "is_fitted_", False))
131class FittedMixin:
132 """Mixin providing utility to guard against using estimators pre-fit."""
134 def _mark_fitted(self) -> None:
135 self.is_fitted_ = True
137 def _require_is_fitted(self, attributes: Iterable[str] | None = None) -> None:
138 if not getattr(self, "is_fitted_", False):
139 raise NotFittedError(f"{type(self).__name__} instance is not fitted yet.")
140 if attributes:
141 missing = [attr for attr in attributes if not hasattr(self, attr)]
142 if missing:
143 raise NotFittedError(
144 f"Estimator {type(self).__name__} is missing fitted attribute(s): {', '.join(missing)}"
145 )
148def check_is_fitted(estimator: FittedMixin, attributes: Iterable[str] | None = None) -> None:
149 """Check if the estimator is fitted by verifying `is_fitted_` and optional attributes."""
150 estimator._require_is_fitted(attributes)
def format_estimator_repr(
    name: str,
    config_pairs: Iterable[tuple[str, Any]],
    children: Iterable[tuple[str, str]],
    *,
    ascii_only: bool = False,
) -> str:
    r"""Compose a tree-style ``__repr__`` string for estimators.

    Parameters
    ----------
    name : str
        The display name for the estimator (typically ``type(self).__name__``).
    config_pairs : Iterable[tuple[str, Any]]
        Sequence of ``(key, value)`` pairs describing configuration values. ``None``
        values are omitted automatically. The values are rendered with ``repr`` to
        preserve type information (strings quoted, etc.).
    children : Iterable[tuple[str, str]]
        Sequence of ``(label, description)`` items describing child artefacts, such as
        fitted submodels or optimizers. Descriptions may contain newlines; they will
        be indented appropriately beneath the child label.
    ascii_only : bool, default=False
        When ``True`` use ASCII connectors (``|--``/``\--``) instead of box drawing
        characters. Useful for environments that do not render Unicode well.

    Returns
    -------
    str
        Multi-line string representation combining the header and child sections.
    """
    config_fragments = [f"{key}={value!r}" for key, value in config_pairs if value is not None]
    header = f"{name}({', '.join(config_fragments)})" if config_fragments else f"{name}()"

    child_list = list(children)
    if not child_list:
        return header

    # Connector glyphs: mid branch, last branch, and the matching continuation
    # pads (sized so continuation lines align under the child label).
    branch_mid, branch_last, pad_mid, pad_last = (
        ("|--", "\\--", "|  ", "   ") if ascii_only else ("├─", "└─", "│ ", "  ")
    )

    lines = [header]
    total = len(child_list)
    for index, (label, description) in enumerate(child_list):
        is_last = index == total - 1
        branch = branch_last if is_last else branch_mid
        pad = pad_last if is_last else pad_mid
        desc_lines = str(description).splitlines() or [""]
        first = desc_lines[0]
        lines.append(f"{branch} {label}: {first}")
        if len(desc_lines) > 1:
            # Subsequent description lines are indented beneath the connector.
            padding = f"{pad} "
            for extra in desc_lines[1:]:
                lines.append(f"{padding}{extra}")

    return "\n".join(lines)
class RuleInspectorMixin(FittedMixin):
    """Mixin that exposes ANFIS rule descriptors for fitted estimators."""

    def get_rules(
        self,
        *,
        include_membership_functions: bool = False,
    ) -> list[tuple[int, ...]] | list[dict[str, Any]]:
        """Return the fuzzy rules learned by the estimator.

        Parameters
        ----------
        include_membership_functions : bool, default=False
            When ``False`` (default), return a list of tuples with the
            membership-function indices per input. When ``True``, return a list
            of dictionaries describing each rule with input names,
            membership-function indices, and their corresponding membership
            function instances (when available).

        Returns
        -------
        list
            Rule definitions either as tuples of integers (default) or as
            dictionaries with a rich description if
            ``include_membership_functions`` is ``True``.
        """
        check_is_fitted(self, attributes=["rules_"])

        # ``rules_`` may legitimately be empty; normalise every rule to a tuple.
        rule_tuples = [tuple(entry) for entry in (getattr(self, "rules_", None) or [])]
        if not include_membership_functions:
            return rule_tuples

        model = getattr(self, "model_", None)
        if model is None:
            raise NotFittedError(f"{type(self).__name__} instance is not fitted yet.")

        membership_map = getattr(model, "membership_functions", {})
        # Prefer recorded feature names; otherwise use the model's MF keys.
        feature_names = list(getattr(self, "feature_names_in_", []) or membership_map.keys())

        descriptors: list[dict[str, Any]] = []
        for position, antecedent_ids in enumerate(rule_tuples):
            antecedents: list[dict[str, Any]] = []
            for slot, mf_index in enumerate(antecedent_ids):
                # Fall back to positional naming (x1, x2, ...) past known names.
                if slot < len(feature_names):
                    input_name = feature_names[slot]
                else:
                    input_name = f"x{slot + 1}"

                candidates = membership_map.get(input_name, [])
                # Out-of-range indices yield None rather than raising.
                membership_fn = candidates[mf_index] if 0 <= mf_index < len(candidates) else None

                antecedents.append(
                    {
                        "input": input_name,
                        "mf_index": int(mf_index),
                        "membership_function": membership_fn,
                    }
                )

            descriptors.append(
                {
                    "index": position,
                    "rule": antecedent_ids,
                    "antecedents": antecedents,
                }
            )

        return descriptors
class RegressorMixinLike:
    """Mixin implementing a default `score` method for regressors."""

    def predict(self, X: np.ndarray) -> np.ndarray:  # pragma: no cover - interface
        """Return predicted targets for ``X``."""
        raise NotImplementedError

    def score(self, X: np.ndarray, y: np.ndarray) -> float:
        """Return the coefficient of determination R^2 of the prediction."""
        actual = np.asarray(y, dtype=float).reshape(-1)
        predicted = np.asarray(self.predict(X), dtype=float).reshape(-1)
        if actual.shape != predicted.shape:
            raise ValueError("Predicted values have a different shape than y.")
        ss_res = float(np.sum((actual - predicted) ** 2))
        ss_tot = float(np.sum((actual - float(np.mean(actual))) ** 2))
        # A constant target has zero variance; report 0.0 instead of dividing by zero.
        if ss_tot == 0.0:
            return 0.0
        return 1.0 - ss_res / ss_tot
class ClassifierMixinLike:
    """Mixin implementing default `score` via simple accuracy."""

    def predict(self, X: np.ndarray) -> np.ndarray:  # pragma: no cover - interface
        """Return predicted class labels for ``X``."""
        raise NotImplementedError

    def score(self, X: np.ndarray, y: np.ndarray) -> float:
        """Return the mean accuracy on the given test data and labels."""
        labels = np.asarray(y)
        predicted = np.asarray(self.predict(X))
        if labels.shape != predicted.shape:
            raise ValueError("Predicted values have a different shape than y.")
        # Empty inputs score 0.0 by convention (np.mean would warn on empty).
        if labels.size == 0:
            return 0.0
        return float(np.mean(labels == predicted))
@dataclass
class ValidationResult:
    """Bundle of validated inputs together with their inferred feature names."""

    # Validated design matrix; presumably (n_samples, n_features) — see ensure_2d_array.
    X: np.ndarray
    # Validated target vector, or None when no targets are supplied.
    y: np.ndarray | None
    # One name per feature column of ``X`` (explicit labels or "x1", "x2", ...).
    feature_names: list[str]
def ensure_2d_array(X: Any) -> tuple[np.ndarray, list[str]]:
    """Validate and convert input data to a 2D float64 numpy array."""
    # Dataframe-like objects expose ``to_numpy`` and (optionally) ``columns``.
    if hasattr(X, "to_numpy"):
        values = X.to_numpy(dtype=float)
        columns = getattr(X, "columns", None)
        names = [str(label) for label in columns] if columns is not None else None
    else:
        values = np.asarray(X, dtype=float)
        names = None

    if values.ndim != 2:
        raise ValueError("Input data must be 2-dimensional (n_samples, n_features).")

    # Fall back to positional names x1..xN when no labels were provided.
    if names is None:
        names = [f"x{i + 1}" for i in range(values.shape[1])]

    return values, names
def ensure_vector(y: Any, *, allow_2d_column: bool = True) -> np.ndarray:
    """Validate and convert target data to a 1D numpy array."""
    array = np.asarray(y)
    if array.ndim == 1:
        return array
    # A single-column 2D array may be flattened when the caller allows it.
    if array.ndim == 2:
        if allow_2d_column and array.shape[1] == 1:
            return array.reshape(-1)
        raise ValueError("Target array must be 1-dimensional or a column vector.")
    raise ValueError("Target array must be 1-dimensional.")
# Backwards compatibility aliases for the previously-private names.
# Deprecated; prefer the public variants defined above.
_ensure_2d_array = ensure_2d_array
_ensure_vector = ensure_vector
def infer_feature_names(X: Any) -> list[str]:
    """Return feature names inferred from the input data structure."""
    # Dataframe-like inputs carry explicit column labels.
    if hasattr(X, "columns"):
        return [str(label) for label in X.columns]
    data = np.asarray(X)
    if data.ndim != 2:
        raise ValueError("Expected 2D array-like input to infer feature names.")
    # Positional fallback naming: x1, x2, ..., xN.
    return [f"x{i + 1}" for i in range(data.shape[1])]