Coverage for anfis_toolbox / regressor.py: 100%
348 statements
« prev ^ index » next — coverage.py v7.13.3, created at 2026-02-05 18:47 -0300
"""High-level regression estimator facade for ANFIS.

The :class:`ANFISRegressor` provides a scikit-learn style interface that wires
up membership-function generation, model construction, and optimizer selection
at instantiation time. It reuses the low-level :mod:`anfis_toolbox` components
under the hood without introducing an external dependency on scikit-learn.
"""
9from __future__ import annotations
11import inspect
12import logging
13import pickle # nosec B403
14from collections.abc import Mapping, Sequence
15from copy import deepcopy
16from pathlib import Path
17from typing import Any, TypeAlias, cast
19import numpy as np
20import numpy.typing as npt
22from .builders import ANFISBuilder
23from .estimator_utils import (
24 BaseEstimatorLike,
25 FittedMixin,
26 RegressorMixinLike,
27 check_is_fitted,
28 ensure_2d_array,
29 ensure_vector,
30 format_estimator_repr,
31)
32from .logging_config import enable_training_logs
33from .losses import LossFunction
34from .membership import MembershipFunction
35from .metrics import ANFISMetrics, MetricValue
36from .model import TSKANFIS, TrainingHistory
37from .optim import (
38 AdamTrainer,
39 BaseTrainer,
40 HybridAdamTrainer,
41 HybridTrainer,
42 PSOTrainer,
43 RMSPropTrainer,
44 SGDTrainer,
45)
47InputConfigValue: TypeAlias = Mapping[str, Any] | Sequence[Any] | MembershipFunction | str | int | None
48NormalizedInputSpec: TypeAlias = dict[str, Any]
50TRAINER_REGISTRY: dict[str, type[BaseTrainer]] = {
51 "hybrid": HybridTrainer,
52 "hybrid_adam": HybridAdamTrainer,
53 "sgd": SGDTrainer,
54 "adam": AdamTrainer,
55 "rmsprop": RMSPropTrainer,
56 "pso": PSOTrainer,
57}
def _ensure_training_logging(verbose: bool) -> None:
    """Turn on package training logs when ``verbose`` is requested.

    A no-op when ``verbose`` is falsy, or when the ``anfis_toolbox`` logger
    already has handlers attached (user-configured logging takes precedence).
    """
    if verbose and not logging.getLogger("anfis_toolbox").handlers:
        enable_training_logs()
69class ANFISRegressor(BaseEstimatorLike, FittedMixin, RegressorMixinLike):
    """Adaptive Neuro-Fuzzy regressor with a scikit-learn style API.

    The estimator manages membership-function synthesis, rule construction, and
    trainer selection so you can focus on calling :meth:`fit`, :meth:`predict`,
    and :meth:`evaluate` with familiar NumPy-like data structures.

    Examples
    --------
    >>> reg = ANFISRegressor()
    >>> reg.fit(X, y)
    ANFISRegressor(...)
    >>> reg.predict(X[:1])
    array([...])

    Parameters
    ----------
    n_mfs : int, default=3
        Default number of membership functions per input.
    mf_type : str, default="gaussian"
        Default membership function family used for automatically generated
        membership functions. Supported values include ``"gaussian"``,
        ``"triangular"``, ``"bell"``, and other names exposed by the
        membership catalogue.
    init : {"grid", "fcm", "random", None}, default="grid"
        Strategy used when inferring membership functions from data. ``None``
        falls back to ``"grid"``.
    overlap : float, default=0.5
        Controls overlap when generating membership functions automatically.
    margin : float, default=0.10
        Margin added around observed data ranges during automatic
        initialization.
    inputs_config : Mapping, optional
        Per-input overrides. Keys may be feature names (when ``X`` is a
        :class:`pandas.DataFrame`) or integer indices. Values may be:

        * ``dict`` with keys among ``{"n_mfs", "mf_type", "init", "overlap",
          "margin", "range", "membership_functions", "mfs"}``.
        * A list/tuple of :class:`MembershipFunction` instances for full control.
        * ``None`` for defaults.
    random_state : int, optional
        Random state forwarded to FCM-based initialization and any stochastic
        optimizers.
    optimizer : str, BaseTrainer, type[BaseTrainer], or None, default="hybrid"
        Trainer identifier or instance used for fitting. Strings map to entries
        in :data:`TRAINER_REGISTRY`. ``None`` defaults to "hybrid".
    optimizer_params : Mapping, optional
        Additional keyword arguments forwarded to the trainer constructor.
    learning_rate, epochs, batch_size, shuffle, verbose : optional scalars
        Common trainer hyper-parameters provided for convenience. When the
        selected trainer supports the parameter it is included automatically.
    loss : str or LossFunction, optional
        Custom loss forwarded to trainers that expose a ``loss`` parameter.
    rules : Sequence[Sequence[int]] | None, optional
        Explicit fuzzy rule indices to use instead of the full Cartesian product.
        Each rule lists the membership-function index per input. ``None`` keeps
        the default exhaustive rule set.
    """
128 def __init__(
129 self,
130 *,
131 n_mfs: int = 3,
132 mf_type: str = "gaussian",
133 init: str | None = "grid",
134 overlap: float = 0.5,
135 margin: float = 0.10,
136 inputs_config: Mapping[Any, Any] | None = None,
137 random_state: int | None = None,
138 optimizer: str | BaseTrainer | type[BaseTrainer] | None = "hybrid",
139 optimizer_params: Mapping[str, Any] | None = None,
140 learning_rate: float | None = None,
141 epochs: int | None = None,
142 batch_size: int | None = None,
143 shuffle: bool | None = None,
144 verbose: bool = False,
145 loss: LossFunction | str | None = None,
146 rules: Sequence[Sequence[int]] | None = None,
147 ) -> None:
148 """Construct an :class:`ANFISRegressor` with the provided hyper-parameters.
150 Parameters
151 ----------
152 n_mfs : int, default=3
153 Default number of membership functions allocated to each input when
154 they are inferred from data.
155 mf_type : str, default="gaussian"
156 Membership function family used for automatically generated
157 membership functions. Supported names mirror the ones exported in
158 :mod:`anfis_toolbox.membership` (e.g. ``"gaussian"``,
159 ``"triangular"``, ``"bell"``).
160 init : {"grid", "fcm", "random", None}, default="grid"
161 Initialization strategy employed when synthesizing membership
162 functions from the training data. ``None`` falls back to
163 ``"grid"``.
164 overlap : float, default=0.5
165 Desired overlap between neighbouring membership functions during
166 automatic construction.
167 margin : float, default=0.10
168 Extra range added around the observed feature minima/maxima when
169 performing grid initialization.
170 inputs_config : Mapping, optional
171 Per-feature overrides for membership configuration. Keys may be
172 feature names (e.g. when ``X`` is a :class:`pandas.DataFrame`),
173 integer indices, or ``"x{i}"`` aliases. Values accept dictionaries
174 with membership keywords (e.g. ``"n_mfs"``, ``"mf_type"``,
175 ``"init"``), explicit membership function lists, or scalars for
176 simple overrides. ``None`` entries keep defaults.
177 random_state : int, optional
178 Seed propagated to stochastic components such as FCM-based
179 initialization and optimizers that rely on randomness.
180 optimizer : str | BaseTrainer | type[BaseTrainer] | None, default="hybrid"
181 Trainer identifier or instance used for fitting. String aliases are
182 looked up in :data:`TRAINER_REGISTRY`. ``None`` defaults to
183 ``"hybrid"``.
184 optimizer_params : Mapping, optional
185 Extra keyword arguments forwarded to the trainer constructor when a
186 string identifier or class is supplied.
187 learning_rate, epochs, batch_size, shuffle, verbose : optional
188 Convenience hyper-parameters that are injected into the selected
189 trainer when supported. ``shuffle`` accepts ``False`` to disable
190 randomisation.
191 loss : str | LossFunction, optional
192 Custom loss forwarded to trainers exposing a ``loss`` parameter.
193 ``None`` keeps the trainer default (typically mean squared error).
194 rules : Sequence[Sequence[int]] | None, optional
195 Optional explicit fuzzy rule definitions. Each rule lists the
196 membership index for every input. ``None`` uses the full Cartesian
197 product of configured membership functions.
198 """
199 self.n_mfs = int(n_mfs)
200 self.mf_type = str(mf_type)
201 self.init = None if init is None else str(init)
202 self.overlap = float(overlap)
203 self.margin = float(margin)
204 self.inputs_config: dict[Any, InputConfigValue] | None = (
205 dict(inputs_config) if inputs_config is not None else None
206 )
207 self.random_state = random_state
208 self.optimizer = optimizer
209 self.optimizer_params = dict(optimizer_params) if optimizer_params is not None else None
210 self.learning_rate = learning_rate
211 self.epochs = epochs
212 self.batch_size = batch_size
213 self.shuffle = shuffle
214 self.verbose = verbose
215 self.loss = loss
216 self.rules = None if rules is None else tuple(tuple(int(idx) for idx in rule) for rule in rules)
218 # Fitted attributes (initialised later)
219 self.model_: TSKANFIS | None = None
220 self.optimizer_: BaseTrainer | None = None
221 self.feature_names_in_: list[str] | None = None
222 self.n_features_in_: int | None = None
223 self.training_history_: TrainingHistory | None = None
224 self.input_specs_: list[NormalizedInputSpec] | None = None
225 self.rules_: list[tuple[int, ...]] | None = None
227 # ------------------------------------------------------------------
228 # Public API
229 # ------------------------------------------------------------------
    def fit(
        self,
        X: npt.ArrayLike,
        y: npt.ArrayLike,
        *,
        validation_data: tuple[np.ndarray, np.ndarray] | None = None,
        validation_frequency: int = 1,
        verbose: bool | None = None,
        **fit_params: Any,
    ) -> ANFISRegressor:
        """Fit the ANFIS regressor on labelled data.

        Parameters
        ----------
        X : array-like
            Training inputs with shape ``(n_samples, n_features)``.
        y : array-like
            Target values aligned with ``X``. One-dimensional vectors are
            accepted and reshaped internally.
        validation_data : tuple[np.ndarray, np.ndarray], optional
            Optional validation split supplied to the underlying trainer.
        validation_frequency : int, default=1
            Frequency (in epochs) at which validation loss is evaluated when
            ``validation_data`` is provided.
        verbose : bool, optional
            Override the estimator's ``verbose`` flag for this fit call; the
            value is stored on the estimator and forwarded to the trainer.
        **fit_params : Any
            Arbitrary keyword arguments forwarded to the trainer ``fit``.

        Returns
        -------
        ANFISRegressor
            Reference to ``self`` for fluent-style chaining.

        Raises
        ------
        ValueError
            If ``X`` and ``y`` contain a different number of samples.
        TypeError
            If the configured trainer returns an object that is not a
            ``dict``-like training history.
        """
        X_arr, feature_names = ensure_2d_array(X)
        y_vec = ensure_vector(y)
        if X_arr.shape[0] != y_vec.shape[0]:
            raise ValueError("X and y must contain the same number of samples.")

        # Record the fitted schema, then resolve per-feature membership specs
        # from the feature names seen in this call.
        self.feature_names_in_ = feature_names
        self.n_features_in_ = X_arr.shape[1]
        self.input_specs_ = self._resolve_input_specs(feature_names)

        # A per-call verbose override is persisted on the estimator.
        if verbose is not None:
            self.verbose = bool(verbose)

        _ensure_training_logging(self.verbose)
        model = self._build_model(X_arr, feature_names)
        self.model_ = model
        trainer = self._instantiate_trainer()
        self.optimizer_ = trainer
        trainer_kwargs: dict[str, Any] = dict(fit_params)
        # setdefault lets explicit fit_params win over the keyword shortcuts;
        # validation_frequency is only forwarded when validation is active or
        # the value deviates from the default of 1.
        if validation_data is not None:
            trainer_kwargs.setdefault("validation_data", validation_data)
        if validation_data is not None or validation_frequency != 1:
            trainer_kwargs.setdefault("validation_frequency", validation_frequency)

        history = trainer.fit(model, X_arr, y_vec, **trainer_kwargs)
        if not isinstance(history, dict):
            raise TypeError("Trainer.fit must return a TrainingHistory dictionary")
        self.training_history_ = history
        self.rules_ = model.rules

        self._mark_fitted()
        return self
310 def predict(self, X: npt.ArrayLike) -> np.ndarray:
311 """Predict regression targets for the provided samples.
313 Parameters
314 ----------
315 X : array-like
316 Samples to evaluate. Accepts one-dimensional arrays (interpreted as
317 a single sample) or matrices with shape ``(n_samples, n_features)``.
319 Returns:
320 -------
321 np.ndarray
322 Vector of predictions with shape ``(n_samples,)``.
324 Raises:
325 ------
326 RuntimeError
327 If the estimator has not been fitted yet.
328 ValueError
329 When the supplied samples do not match the fitted feature count.
330 """
331 check_is_fitted(self, attributes=["model_"])
332 X_arr = np.asarray(X, dtype=float)
333 if X_arr.ndim == 1:
334 X_arr = X_arr.reshape(1, -1)
335 else:
336 X_arr, _ = ensure_2d_array(X)
338 if self.n_features_in_ is None:
339 raise RuntimeError("Model must be fitted before calling predict.")
340 if X_arr.shape[1] != self.n_features_in_:
341 raise ValueError(f"Feature mismatch: expected {self.n_features_in_}, got {X_arr.shape[1]}.")
343 model = self.model_
344 if model is None:
345 raise RuntimeError("Model must be fitted before calling predict.")
346 preds = model.predict(X_arr)
347 return np.asarray(preds, dtype=float).reshape(-1)
    def evaluate(
        self,
        X: npt.ArrayLike,
        y: npt.ArrayLike,
        *,
        return_dict: bool = True,
        print_results: bool = True,
    ) -> Mapping[str, MetricValue] | None:
        """Evaluate predictive performance on a dataset.

        Parameters
        ----------
        X : array-like
            Evaluation inputs with shape ``(n_samples, n_features)``.
        y : array-like
            Ground-truth targets aligned with ``X``.
        return_dict : bool, default=True
            When ``True``, return the computed metric dictionary. When
            ``False``, only perform side effects (printing) and return ``None``.
        print_results : bool, default=True
            Log a human-readable summary to stdout. Set to ``False`` to
            suppress printing.

        Returns
        -------
        Mapping[str, MetricValue] | None
            Regression metrics from :meth:`ANFISMetrics.regression_metrics`
            when ``return_dict`` is ``True``; otherwise ``None``.

        Raises
        ------
        RuntimeError
            If called before ``fit``.
        ValueError
            When ``X`` and ``y`` disagree on the sample count.
        """
        check_is_fitted(self, attributes=["model_"])
        X_arr, _ = ensure_2d_array(X)
        y_vec = ensure_vector(y)
        preds = self.predict(X_arr)
        metrics: dict[str, MetricValue] = ANFISMetrics.regression_metrics(y_vec, preds)
        if print_results:

            def _is_effectively_nan(value: Any) -> bool:
                # A metric is hidden from the report when it is None, a NaN
                # scalar, or a numeric array consisting entirely of NaN.
                if value is None:
                    return True
                if isinstance(value, (float, np.floating)):
                    return bool(np.isnan(value))
                if isinstance(value, (int, np.integer)):
                    return False
                if isinstance(value, np.ndarray):
                    # Empty arrays are printed as-is; np.isnan(...).all() on an
                    # empty array would be vacuously True.
                    if value.size == 0:
                        return False
                    if np.issubdtype(value.dtype, np.number):
                        return bool(np.isnan(value.astype(float)).all())
                    return False
                return False

            print("ANFISRegressor evaluation:")  # noqa: T201
            for key, value in metrics.items():
                if _is_effectively_nan(value):
                    continue
                # Render each metric according to its type: fixed-precision
                # floats, plain ints, pretty-printed arrays, repr for the rest.
                if isinstance(value, (float, np.floating)):
                    display_value = f"{float(value):.6f}"
                    print(f" {key}: {display_value}")  # noqa: T201
                elif isinstance(value, (int, np.integer)):
                    print(f" {key}: {int(value)}")  # noqa: T201
                elif isinstance(value, np.ndarray):
                    array_repr = np.array2string(value, precision=6, suppress_small=True)
                    if "\n" in array_repr:
                        # Multi-line arrays are indented under their key.
                        indented = "\n ".join(array_repr.splitlines())
                        print(f" {key}:\n {indented}")  # noqa: T201
                    else:
                        print(f" {key}: {array_repr}")  # noqa: T201
                else:
                    print(f" {key}: {value}")  # noqa: T201
        return metrics if return_dict else None
429 def get_rules(self) -> tuple[tuple[int, ...], ...]:
430 """Return the fuzzy rule index combinations used by the fitted model.
432 Returns:
433 -------
434 tuple[tuple[int, ...], ...]
435 Immutable tuple containing one tuple per fuzzy rule, where each
436 inner tuple lists the membership index chosen for each input.
438 Raises:
439 ------
440 RuntimeError
441 If invoked before the estimator is fitted.
442 """
443 check_is_fitted(self, attributes=["rules_"])
444 if not self.rules_:
445 return ()
446 return tuple(tuple(rule) for rule in self.rules_)
448 def save(self, filepath: str | Path) -> None:
449 """Serialize this estimator (and its fitted state) using ``pickle``."""
450 path = Path(filepath)
451 path.parent.mkdir(parents=True, exist_ok=True)
452 with path.open("wb") as stream:
453 pickle.dump(self, stream) # nosec B301
455 @classmethod
456 def load(cls, filepath: str | Path) -> ANFISRegressor:
457 """Load a pickled estimator from ``filepath`` and validate its type."""
458 path = Path(filepath)
459 with path.open("rb") as stream:
460 estimator = pickle.load(stream) # nosec B301
461 if not isinstance(estimator, cls):
462 raise TypeError(f"Expected pickled {cls.__name__} instance, got {type(estimator).__name__}.")
463 return estimator
465 def __repr__(self) -> str:
466 """Return a formatted representation summarising configuration and fitted artefacts."""
467 return format_estimator_repr(
468 type(self).__name__,
469 self._repr_config_pairs(),
470 self._repr_children_entries(),
471 )
473 def _more_tags(self) -> dict[str, Any]: # pragma: no cover - informational hook
474 return {
475 "estimator_type": "regressor",
476 "requires_y": True,
477 }
479 # ------------------------------------------------------------------
480 # Helpers
481 # ------------------------------------------------------------------
482 def _resolve_input_specs(self, feature_names: list[str]) -> list[NormalizedInputSpec]:
483 resolved: list[NormalizedInputSpec] = []
484 for idx, name in enumerate(feature_names):
485 spec = self._fetch_input_config(name, idx)
486 resolved.append(self._normalize_input_spec(spec))
487 return resolved
489 # ------------------------------------------------------------------
490 # Representation helpers
491 # ------------------------------------------------------------------
492 def _repr_config_pairs(self) -> list[tuple[str, Any]]:
493 optimizer_label = self._describe_optimizer_config(self.optimizer)
494 pairs: list[tuple[str, Any]] = [
495 ("n_mfs", self.n_mfs),
496 ("mf_type", self.mf_type),
497 ("init", self.init),
498 ("overlap", self.overlap),
499 ("margin", self.margin),
500 ("random_state", self.random_state),
501 ("optimizer", optimizer_label),
502 ("learning_rate", self.learning_rate),
503 ("epochs", self.epochs),
504 ("batch_size", self.batch_size),
505 ("shuffle", self.shuffle),
506 ("loss", self.loss),
507 ]
508 if self.rules is not None:
509 pairs.append(("rules", f"preset:{len(self.rules)}"))
510 if self.optimizer_params:
511 pairs.append(("optimizer_params", self.optimizer_params))
512 return pairs
514 def _repr_children_entries(self) -> list[tuple[str, str]]:
515 if not getattr(self, "is_fitted_", False):
516 return []
518 children: list[tuple[str, str]] = []
519 model = getattr(self, "model_", None)
520 if model is not None:
521 children.append(("model_", self._summarize_model(model)))
523 optimizer = getattr(self, "optimizer_", None)
524 if optimizer is not None:
525 children.append(("optimizer_", self._summarize_optimizer(optimizer)))
527 history = getattr(self, "training_history_", None)
528 if isinstance(history, Mapping) and history:
529 children.append(("training_history_", self._summarize_history(history)))
531 learned_rules = getattr(self, "rules_", None)
532 if learned_rules is not None:
533 children.append(("rules_", f"{len(learned_rules)} learned"))
535 feature_names = getattr(self, "feature_names_in_", None)
536 if feature_names is not None:
537 children.append(("feature_names_in_", ", ".join(feature_names)))
539 return children
541 @staticmethod
542 def _describe_optimizer_config(optimizer: Any) -> Any:
543 if optimizer is None:
544 return None
545 if isinstance(optimizer, str):
546 return optimizer
547 if inspect.isclass(optimizer):
548 return optimizer.__name__
549 if isinstance(optimizer, BaseTrainer):
550 return type(optimizer).__name__
551 return repr(optimizer)
553 def _summarize_model(self, model: Any) -> str:
554 name = type(model).__name__
555 parts = [name]
556 n_inputs = getattr(model, "n_inputs", None)
557 n_rules = getattr(model, "n_rules", None)
558 if n_inputs is not None:
559 parts.append(f"n_inputs={n_inputs}")
560 if n_rules is not None:
561 parts.append(f"n_rules={n_rules}")
562 input_names = getattr(model, "input_names", None)
563 if input_names:
564 parts.append(f"inputs={list(input_names)}")
565 mf_map = getattr(model, "membership_functions", None)
566 if isinstance(mf_map, Mapping) and mf_map:
567 counts = [len(mf_map[name]) for name in getattr(model, "input_names", mf_map.keys())]
568 parts.append(f"mfs_per_input={counts}")
569 return ", ".join(parts)
571 def _summarize_optimizer(self, optimizer: BaseTrainer) -> str:
572 name = type(optimizer).__name__
573 fields: list[str] = []
574 for attr in ("learning_rate", "epochs", "batch_size", "shuffle", "verbose", "loss"):
575 if hasattr(optimizer, attr):
576 value = getattr(optimizer, attr)
577 if value is not None:
578 fields.append(f"{attr}={value!r}")
579 if hasattr(optimizer, "__dict__") and not fields:
580 # Fall back to repr if no recognised fields were populated
581 return repr(optimizer)
582 return f"{name}({', '.join(fields)})" if fields else name
584 @staticmethod
585 def _summarize_history(history: Mapping[str, Any]) -> str:
586 segments: list[str] = []
587 for key in ("train", "val", "validation", "metrics"):
588 if key in history and isinstance(history[key], Sequence):
589 series = history[key]
590 length = len(series)
591 if length == 0:
592 segments.append(f"{key}=0")
593 else:
594 tail = series[-1]
595 if isinstance(tail, (float, np.floating)):
596 segments.append(f"{key}={length} (last={float(tail):.4f})")
597 else:
598 segments.append(f"{key}={length}")
599 return ", ".join(segments) if segments else "{}"
601 def _fetch_input_config(self, name: str, index: int) -> InputConfigValue:
602 if self.inputs_config is None:
603 return None
604 spec = self.inputs_config.get(name)
605 if spec is not None:
606 return spec
607 spec = self.inputs_config.get(index)
608 if spec is not None:
609 return spec
610 alt_key = f"x{index + 1}"
611 return self.inputs_config.get(alt_key)
    def _normalize_input_spec(self, spec: InputConfigValue) -> NormalizedInputSpec:
        """Expand a user-supplied input override into the canonical spec dict.

        Starts from the estimator-level defaults and applies ``spec`` on top:
        membership-function lists take full control, mappings override
        individual keys, a bare string sets ``mf_type``, and a bare int sets
        ``n_mfs``.

        Raises:
        ------
        TypeError
            If ``spec`` is of an unsupported type.
        """
        # Baseline spec seeded from the estimator-wide defaults.
        config: NormalizedInputSpec = {
            "n_mfs": self.n_mfs,
            "mf_type": self.mf_type,
            "init": self.init,
            "overlap": self.overlap,
            "margin": self.margin,
            "range": None,
            "membership_functions": None,
        }
        if spec is None:
            return config
        # A sequence consisting solely of MembershipFunction instances is an
        # explicit per-input membership list.
        if isinstance(spec, (list, tuple)) and all(isinstance(mf, MembershipFunction) for mf in spec):
            config["membership_functions"] = list(spec)
            return config
        if isinstance(spec, MembershipFunction):
            config["membership_functions"] = [spec]
            return config
        if isinstance(spec, Mapping):
            mapping = dict(spec)
            # "mfs" is accepted as a short alias for "membership_functions".
            if "mfs" in mapping and "membership_functions" not in mapping:
                mapping = {**mapping, "membership_functions": mapping["mfs"]}
            for key in ("n_mfs", "mf_type", "init", "overlap", "margin", "range", "membership_functions"):
                # "init" is special: an explicit None is a valid override; for
                # every other key, None means "keep the default".
                if key in mapping and (mapping[key] is not None or key == "init"):
                    config[key] = mapping[key]
            return config
        # Scalar shortcuts: a string sets the mf family, an int the mf count.
        if isinstance(spec, str):
            config["mf_type"] = spec
            return config
        if isinstance(spec, int):
            config["n_mfs"] = int(spec)
            return config
        raise TypeError(f"Unsupported input configuration type: {type(spec)!r}")
647 def _build_model(self, X: np.ndarray, feature_names: list[str]) -> TSKANFIS:
648 builder = ANFISBuilder()
649 if self.input_specs_ is None:
650 raise RuntimeError("Input specifications must be resolved before building the model.")
651 for idx, name in enumerate(feature_names):
652 column = X[:, idx]
653 spec = self.input_specs_[idx]
654 mf_list = spec.get("membership_functions")
655 range_override = spec.get("range")
656 if mf_list is not None:
657 builder.input_mfs[name] = [cast(MembershipFunction, mf) for mf in mf_list]
658 if range_override is not None:
659 range_tuple = tuple(float(v) for v in range_override)
660 if len(range_tuple) != 2:
661 raise ValueError("range overrides must contain exactly two values")
662 builder.input_ranges[name] = (range_tuple[0], range_tuple[1])
663 else:
664 builder.input_ranges[name] = (float(np.min(column)), float(np.max(column)))
665 continue
666 if range_override is not None:
667 range_tuple = tuple(float(v) for v in range_override)
668 if len(range_tuple) != 2:
669 raise ValueError("range overrides must contain exactly two values")
670 rmin, rmax = range_tuple
671 builder.add_input(
672 name,
673 float(rmin),
674 float(rmax),
675 int(spec["n_mfs"]),
676 str(spec["mf_type"]),
677 overlap=float(spec["overlap"]),
678 )
679 else:
680 init_strategy = spec.get("init")
681 init_arg = None if init_strategy is None else str(init_strategy)
682 builder.add_input_from_data(
683 name,
684 column,
685 n_mfs=int(spec["n_mfs"]),
686 mf_type=str(spec["mf_type"]),
687 overlap=float(spec["overlap"]),
688 margin=float(spec["margin"]),
689 init=init_arg,
690 random_state=self.random_state,
691 )
692 builder.set_rules(self.rules)
693 return builder.build()
695 def _instantiate_trainer(self) -> BaseTrainer:
696 optimizer = self.optimizer if self.optimizer is not None else "hybrid"
697 if isinstance(optimizer, BaseTrainer):
698 trainer = deepcopy(optimizer)
699 self._apply_runtime_overrides(trainer)
700 return trainer
701 if inspect.isclass(optimizer) and issubclass(optimizer, BaseTrainer):
702 params = self._collect_trainer_params(optimizer)
703 return optimizer(**params)
704 if isinstance(optimizer, str):
705 key = optimizer.lower()
706 if key not in TRAINER_REGISTRY:
707 supported = ", ".join(sorted(TRAINER_REGISTRY.keys()))
708 raise ValueError(f"Unknown optimizer '{optimizer}'. Supported: {supported}")
709 trainer_cls = TRAINER_REGISTRY[key]
710 params = self._collect_trainer_params(trainer_cls)
711 return trainer_cls(**params)
712 raise TypeError("optimizer must be a string identifier, BaseTrainer instance, or BaseTrainer subclass")
714 def _collect_trainer_params(self, trainer_cls: type[BaseTrainer]) -> dict[str, Any]:
715 params: dict[str, Any] = {}
716 if self.optimizer_params is not None:
717 params.update(self.optimizer_params)
719 overrides = {
720 "learning_rate": self.learning_rate,
721 "epochs": self.epochs,
722 "batch_size": self.batch_size,
723 "shuffle": self.shuffle,
724 "verbose": self.verbose,
725 "loss": self.loss,
726 }
727 for key, value in overrides.items():
728 if value is not None and key not in params:
729 params[key] = value
730 # Ensure boolean defaults propagate when value could be False
731 if self.shuffle is not None:
732 params.setdefault("shuffle", self.shuffle)
733 params.setdefault("verbose", self.verbose)
735 sig = inspect.signature(trainer_cls)
736 filtered: dict[str, Any] = {}
737 for name in sig.parameters:
738 if name == "self":
739 continue
740 if name in params:
741 filtered[name] = params[name]
742 return filtered
744 def _apply_runtime_overrides(self, trainer: BaseTrainer) -> None:
745 for attr, value in (
746 ("learning_rate", self.learning_rate),
747 ("epochs", self.epochs),
748 ("batch_size", self.batch_size),
749 ("shuffle", self.shuffle),
750 ("verbose", self.verbose),
751 ("loss", self.loss),
752 ):
753 if value is not None and hasattr(trainer, attr):
754 setattr(trainer, attr, value)
755 if hasattr(trainer, "verbose") and self.verbose is not None:
756 trainer.verbose = self.verbose