Coverage for anfis_toolbox / regressor.py: 100%

348 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-05 18:47 -0300

1"""High-level regression estimator facade for ANFIS. 

2 

3The :class:`ANFISRegressor` provides a scikit-learn style interface that wires 

4up membership-function generation, model construction, and optimizer selection 

5at instantiation time. It reuses the low-level :mod:`anfis_toolbox` components 

6under the hood without introducing an external dependency on scikit-learn. 

7""" 

8 

9from __future__ import annotations 

10 

11import inspect 

12import logging 

13import pickle # nosec B403 

14from collections.abc import Mapping, Sequence 

15from copy import deepcopy 

16from pathlib import Path 

17from typing import Any, TypeAlias, cast 

18 

19import numpy as np 

20import numpy.typing as npt 

21 

22from .builders import ANFISBuilder 

23from .estimator_utils import ( 

24 BaseEstimatorLike, 

25 FittedMixin, 

26 RegressorMixinLike, 

27 check_is_fitted, 

28 ensure_2d_array, 

29 ensure_vector, 

30 format_estimator_repr, 

31) 

32from .logging_config import enable_training_logs 

33from .losses import LossFunction 

34from .membership import MembershipFunction 

35from .metrics import ANFISMetrics, MetricValue 

36from .model import TSKANFIS, TrainingHistory 

37from .optim import ( 

38 AdamTrainer, 

39 BaseTrainer, 

40 HybridAdamTrainer, 

41 HybridTrainer, 

42 PSOTrainer, 

43 RMSPropTrainer, 

44 SGDTrainer, 

45) 

46 

# A single per-input override as accepted by ``inputs_config``: a mapping of
# membership keywords, a sequence of MembershipFunction objects, a single
# MembershipFunction, a mf_type name, an MF count, or None for defaults.
InputConfigValue: TypeAlias = Mapping[str, Any] | Sequence[Any] | MembershipFunction | str | int | None
# Canonical per-input spec produced by ``_normalize_input_spec`` (keys:
# n_mfs, mf_type, init, overlap, margin, range, membership_functions).
NormalizedInputSpec: TypeAlias = dict[str, Any]

# Maps the string aliases accepted by ``ANFISRegressor(optimizer=...)`` to
# trainer classes. Lookup is case-insensitive: keys must stay lower-case
# because ``_instantiate_trainer`` lowercases the user's string first.
TRAINER_REGISTRY: dict[str, type[BaseTrainer]] = {
    "hybrid": HybridTrainer,
    "hybrid_adam": HybridAdamTrainer,
    "sgd": SGDTrainer,
    "adam": AdamTrainer,
    "rmsprop": RMSPropTrainer,
    "pso": PSOTrainer,
}

58 

59 

def _ensure_training_logging(verbose: bool) -> None:
    """Turn on package training logs when verbosity was requested.

    No-op unless ``verbose`` is truthy. Also a no-op when the
    ``anfis_toolbox`` logger already carries handlers, so user-configured
    logging is never clobbered.
    """
    if verbose and not logging.getLogger("anfis_toolbox").handlers:
        enable_training_logs()

67 

68 

69class ANFISRegressor(BaseEstimatorLike, FittedMixin, RegressorMixinLike): 

70 """Adaptive Neuro-Fuzzy regressor with a scikit-learn style API. 

71 

72 The estimator manages membership-function synthesis, rule construction, and 

73 trainer selection so you can focus on calling :meth:`fit`, :meth:`predict`, 

74 and :meth:`evaluate` with familiar NumPy-like data structures. 

75 

76 Examples: 

77 -------- 

78 >>> reg = ANFISRegressor() 

79 >>> reg.fit(X, y) 

80 ANFISRegressor(...) 

81 >>> reg.predict(X[:1]) 

82 array([...]) 

83 

84 Parameters 

85 ---------- 

86 n_mfs : int, default=3 

87 Default number of membership functions per input. 

88 mf_type : str, default="gaussian" 

89 Default membership function family used for automatically generated 

90 membership functions. Supported values include ``"gaussian"``, 

91 ``"triangular"``, ``"bell"``, and other names exposed by the 

92 membership catalogue. 

93 init : {"grid", "fcm", "random", None}, default="grid" 

94 Strategy used when inferring membership functions from data. ``None`` 

95 falls back to ``"grid"``. 

96 overlap : float, default=0.5 

97 Controls overlap when generating membership functions automatically. 

98 margin : float, default=0.10 

99 Margin added around observed data ranges during automatic 

100 initialization. 

101 inputs_config : Mapping, optional 

102 Per-input overrides. Keys may be feature names (when ``X`` is a 

103 :class:`pandas.DataFrame`) or integer indices. Values may be: 

104 

105 * ``dict`` with keys among ``{"n_mfs", "mf_type", "init", "overlap", 

106 "margin", "range", "membership_functions", "mfs"}``. 

107 * A list/tuple of :class:`MembershipFunction` instances for full control. 

108 * ``None`` for defaults. 

109 random_state : int, optional 

110 Random state forwarded to FCM-based initialization and any stochastic 

111 optimizers. 

112 optimizer : str, BaseTrainer, type[BaseTrainer], or None, default="hybrid" 

113 Trainer identifier or instance used for fitting. Strings map to entries 

114 in :data:`TRAINER_REGISTRY`. ``None`` defaults to "hybrid". 

115 optimizer_params : Mapping, optional 

116 Additional keyword arguments forwarded to the trainer constructor. 

117 learning_rate, epochs, batch_size, shuffle, verbose : optional scalars 

118 Common trainer hyper-parameters provided for convenience. When the 

119 selected trainer supports the parameter it is included automatically. 

120 loss : str or LossFunction, optional 

121 Custom loss forwarded to trainers that expose a ``loss`` parameter. 

122 rules : Sequence[Sequence[int]] | None, optional 

123 Explicit fuzzy rule indices to use instead of the full Cartesian product. Each 

124 rule lists the membership-function index per input. ``None`` keeps the default 

125 exhaustive rule set. 

126 """ 

127 

    def __init__(
        self,
        *,
        n_mfs: int = 3,
        mf_type: str = "gaussian",
        init: str | None = "grid",
        overlap: float = 0.5,
        margin: float = 0.10,
        inputs_config: Mapping[Any, Any] | None = None,
        random_state: int | None = None,
        optimizer: str | BaseTrainer | type[BaseTrainer] | None = "hybrid",
        optimizer_params: Mapping[str, Any] | None = None,
        learning_rate: float | None = None,
        epochs: int | None = None,
        batch_size: int | None = None,
        shuffle: bool | None = None,
        verbose: bool = False,
        loss: LossFunction | str | None = None,
        rules: Sequence[Sequence[int]] | None = None,
    ) -> None:
        """Construct an :class:`ANFISRegressor` with the provided hyper-parameters.

        Parameters
        ----------
        n_mfs : int, default=3
            Default number of membership functions allocated to each input
            when they are inferred from data.
        mf_type : str, default="gaussian"
            Membership function family for automatically generated membership
            functions (names mirror :mod:`anfis_toolbox.membership`, e.g.
            ``"gaussian"``, ``"triangular"``, ``"bell"``).
        init : {"grid", "fcm", "random", None}, default="grid"
            Initialization strategy when synthesizing membership functions
            from data. ``None`` falls back to ``"grid"``.
        overlap : float, default=0.5
            Desired overlap between neighbouring membership functions.
        margin : float, default=0.10
            Extra range around observed minima/maxima for grid initialization.
        inputs_config : Mapping, optional
            Per-feature overrides keyed by feature name, integer index, or
            ``"x{i}"`` alias. Values may be dicts of membership keywords,
            membership-function lists, scalars, or ``None`` for defaults.
        random_state : int, optional
            Seed forwarded to FCM initialization and stochastic optimizers.
        optimizer : str | BaseTrainer | type[BaseTrainer] | None, default="hybrid"
            Trainer identifier, class, or instance; string aliases are looked
            up in :data:`TRAINER_REGISTRY`. ``None`` means ``"hybrid"``.
        optimizer_params : Mapping, optional
            Extra keyword arguments for the trainer constructor.
        learning_rate, epochs, batch_size, shuffle, verbose : optional
            Convenience hyper-parameters injected into the trainer when it
            supports them. ``shuffle`` accepts ``False`` explicitly.
        loss : str | LossFunction, optional
            Custom loss for trainers exposing a ``loss`` parameter.
        rules : Sequence[Sequence[int]] | None, optional
            Explicit fuzzy rules (one membership index per input per rule);
            ``None`` uses the full Cartesian product.
        """
        # Coerce scalar hyper-parameters eagerly so bad values fail at
        # construction time rather than deep inside fit().
        self.n_mfs = int(n_mfs)
        self.mf_type = str(mf_type)
        self.init = None if init is None else str(init)
        self.overlap = float(overlap)
        self.margin = float(margin)
        # Defensive copies: callers may mutate the mappings they handed in.
        self.inputs_config: dict[Any, InputConfigValue] | None = (
            dict(inputs_config) if inputs_config is not None else None
        )
        self.random_state = random_state
        self.optimizer = optimizer
        self.optimizer_params = dict(optimizer_params) if optimizer_params is not None else None
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.verbose = verbose
        self.loss = loss
        # Freeze rule definitions into immutable nested tuples of ints.
        self.rules = None if rules is None else tuple(tuple(int(idx) for idx in rule) for rule in rules)

        # Fitted attributes (initialised later)
        self.model_: TSKANFIS | None = None          # trained TSK-ANFIS network
        self.optimizer_: BaseTrainer | None = None   # trainer used in the last fit
        self.feature_names_in_: list[str] | None = None
        self.n_features_in_: int | None = None
        self.training_history_: TrainingHistory | None = None
        self.input_specs_: list[NormalizedInputSpec] | None = None
        self.rules_: list[tuple[int, ...]] | None = None  # rules learned/used by model_

226 

227 # ------------------------------------------------------------------ 

228 # Public API 

229 # ------------------------------------------------------------------ 

230 def fit( 

231 self, 

232 X: npt.ArrayLike, 

233 y: npt.ArrayLike, 

234 *, 

235 validation_data: tuple[np.ndarray, np.ndarray] | None = None, 

236 validation_frequency: int = 1, 

237 verbose: bool | None = None, 

238 **fit_params: Any, 

239 ) -> ANFISRegressor: 

240 """Fit the ANFIS regressor on labelled data. 

241 

242 Parameters 

243 ---------- 

244 X : array-like 

245 Training inputs with shape ``(n_samples, n_features)``. 

246 y : array-like 

247 Target values aligned with ``X``. One-dimensional vectors are 

248 accepted and reshaped internally. 

249 validation_data : tuple[np.ndarray, np.ndarray], optional 

250 Optional validation split supplied to the underlying trainer. Both 

251 arrays must already be numeric and share the same row count. 

252 validation_frequency : int, default=1 

253 Frequency (in epochs) at which validation loss is evaluated when 

254 ``validation_data`` is provided. 

255 verbose : bool, optional 

256 Override the estimator's ``verbose`` flag for this fit call. When 

257 supplied, the value is stored on the estimator and forwarded to the 

258 trainer configuration. 

259 **fit_params : Any 

260 Arbitrary keyword arguments forwarded to the trainer ``fit`` 

261 method. 

262 

263 Returns: 

264 ------- 

265 ANFISRegressor 

266 Reference to ``self`` for fluent-style chaining. 

267 

268 Raises: 

269 ------ 

270 ValueError 

271 If ``X`` and ``y`` contain a different number of samples. 

272 ValueError 

273 If validation frequency is less than one. 

274 TypeError 

275 If the configured trainer returns an object that is not a 

276 ``dict``-like training history. 

277 """ 

278 X_arr, feature_names = ensure_2d_array(X) 

279 y_vec = ensure_vector(y) 

280 if X_arr.shape[0] != y_vec.shape[0]: 

281 raise ValueError("X and y must contain the same number of samples.") 

282 

283 self.feature_names_in_ = feature_names 

284 self.n_features_in_ = X_arr.shape[1] 

285 self.input_specs_ = self._resolve_input_specs(feature_names) 

286 

287 if verbose is not None: 

288 self.verbose = bool(verbose) 

289 

290 _ensure_training_logging(self.verbose) 

291 model = self._build_model(X_arr, feature_names) 

292 self.model_ = model 

293 trainer = self._instantiate_trainer() 

294 self.optimizer_ = trainer 

295 trainer_kwargs: dict[str, Any] = dict(fit_params) 

296 if validation_data is not None: 

297 trainer_kwargs.setdefault("validation_data", validation_data) 

298 if validation_data is not None or validation_frequency != 1: 

299 trainer_kwargs.setdefault("validation_frequency", validation_frequency) 

300 

301 history = trainer.fit(model, X_arr, y_vec, **trainer_kwargs) 

302 if not isinstance(history, dict): 

303 raise TypeError("Trainer.fit must return a TrainingHistory dictionary") 

304 self.training_history_ = history 

305 self.rules_ = model.rules 

306 

307 self._mark_fitted() 

308 return self 

309 

310 def predict(self, X: npt.ArrayLike) -> np.ndarray: 

311 """Predict regression targets for the provided samples. 

312 

313 Parameters 

314 ---------- 

315 X : array-like 

316 Samples to evaluate. Accepts one-dimensional arrays (interpreted as 

317 a single sample) or matrices with shape ``(n_samples, n_features)``. 

318 

319 Returns: 

320 ------- 

321 np.ndarray 

322 Vector of predictions with shape ``(n_samples,)``. 

323 

324 Raises: 

325 ------ 

326 RuntimeError 

327 If the estimator has not been fitted yet. 

328 ValueError 

329 When the supplied samples do not match the fitted feature count. 

330 """ 

331 check_is_fitted(self, attributes=["model_"]) 

332 X_arr = np.asarray(X, dtype=float) 

333 if X_arr.ndim == 1: 

334 X_arr = X_arr.reshape(1, -1) 

335 else: 

336 X_arr, _ = ensure_2d_array(X) 

337 

338 if self.n_features_in_ is None: 

339 raise RuntimeError("Model must be fitted before calling predict.") 

340 if X_arr.shape[1] != self.n_features_in_: 

341 raise ValueError(f"Feature mismatch: expected {self.n_features_in_}, got {X_arr.shape[1]}.") 

342 

343 model = self.model_ 

344 if model is None: 

345 raise RuntimeError("Model must be fitted before calling predict.") 

346 preds = model.predict(X_arr) 

347 return np.asarray(preds, dtype=float).reshape(-1) 

348 

    def evaluate(
        self,
        X: npt.ArrayLike,
        y: npt.ArrayLike,
        *,
        return_dict: bool = True,
        print_results: bool = True,
    ) -> Mapping[str, MetricValue] | None:
        """Evaluate predictive performance on a dataset.

        Parameters
        ----------
        X : array-like
            Evaluation inputs with shape ``(n_samples, n_features)``.
        y : array-like
            Ground-truth targets aligned with ``X``.
        return_dict : bool, default=True
            When ``True``, return the computed metric dictionary; when
            ``False``, only perform side effects and return ``None``.
        print_results : bool, default=True
            Print a human-readable summary to stdout.

        Returns:
        -------
        Mapping[str, MetricValue] | None
            Regression metrics from :meth:`ANFISMetrics.regression_metrics`
            when ``return_dict`` is ``True``; otherwise ``None``.

        Raises:
        ------
        RuntimeError
            If called before ``fit``.
        ValueError
            When ``X`` and ``y`` disagree on the sample count.
        """
        check_is_fitted(self, attributes=["model_"])
        X_arr, _ = ensure_2d_array(X)
        y_vec = ensure_vector(y)
        preds = self.predict(X_arr)
        metrics: dict[str, MetricValue] = ANFISMetrics.regression_metrics(y_vec, preds)
        if print_results:

            def _is_effectively_nan(value: Any) -> bool:
                # Filter out missing metrics: None, NaN floats, and all-NaN
                # numeric arrays are skipped when printing.
                if value is None:
                    return True
                if isinstance(value, (float, np.floating)):
                    return bool(np.isnan(value))
                if isinstance(value, (int, np.integer)):
                    return False
                if isinstance(value, np.ndarray):
                    if value.size == 0:
                        return False
                    if np.issubdtype(value.dtype, np.number):
                        return bool(np.isnan(value.astype(float)).all())
                    return False
                return False

            print("ANFISRegressor evaluation:")  # noqa: T201
            for key, value in metrics.items():
                if _is_effectively_nan(value):
                    continue
                if isinstance(value, (float, np.floating)):
                    # Fixed 6-decimal formatting for scalar floats.
                    display_value = f"{float(value):.6f}"
                    print(f" {key}: {display_value}")  # noqa: T201
                elif isinstance(value, (int, np.integer)):
                    print(f" {key}: {int(value)}")  # noqa: T201
                elif isinstance(value, np.ndarray):
                    array_repr = np.array2string(value, precision=6, suppress_small=True)
                    if "\n" in array_repr:
                        # Indent multi-line array output under its key.
                        indented = "\n ".join(array_repr.splitlines())
                        print(f" {key}:\n {indented}")  # noqa: T201
                    else:
                        print(f" {key}: {array_repr}")  # noqa: T201
                else:
                    print(f" {key}: {value}")  # noqa: T201
        return metrics if return_dict else None

428 

429 def get_rules(self) -> tuple[tuple[int, ...], ...]: 

430 """Return the fuzzy rule index combinations used by the fitted model. 

431 

432 Returns: 

433 ------- 

434 tuple[tuple[int, ...], ...] 

435 Immutable tuple containing one tuple per fuzzy rule, where each 

436 inner tuple lists the membership index chosen for each input. 

437 

438 Raises: 

439 ------ 

440 RuntimeError 

441 If invoked before the estimator is fitted. 

442 """ 

443 check_is_fitted(self, attributes=["rules_"]) 

444 if not self.rules_: 

445 return () 

446 return tuple(tuple(rule) for rule in self.rules_) 

447 

448 def save(self, filepath: str | Path) -> None: 

449 """Serialize this estimator (and its fitted state) using ``pickle``.""" 

450 path = Path(filepath) 

451 path.parent.mkdir(parents=True, exist_ok=True) 

452 with path.open("wb") as stream: 

453 pickle.dump(self, stream) # nosec B301 

454 

455 @classmethod 

456 def load(cls, filepath: str | Path) -> ANFISRegressor: 

457 """Load a pickled estimator from ``filepath`` and validate its type.""" 

458 path = Path(filepath) 

459 with path.open("rb") as stream: 

460 estimator = pickle.load(stream) # nosec B301 

461 if not isinstance(estimator, cls): 

462 raise TypeError(f"Expected pickled {cls.__name__} instance, got {type(estimator).__name__}.") 

463 return estimator 

464 

465 def __repr__(self) -> str: 

466 """Return a formatted representation summarising configuration and fitted artefacts.""" 

467 return format_estimator_repr( 

468 type(self).__name__, 

469 self._repr_config_pairs(), 

470 self._repr_children_entries(), 

471 ) 

472 

473 def _more_tags(self) -> dict[str, Any]: # pragma: no cover - informational hook 

474 return { 

475 "estimator_type": "regressor", 

476 "requires_y": True, 

477 } 

478 

479 # ------------------------------------------------------------------ 

480 # Helpers 

481 # ------------------------------------------------------------------ 

482 def _resolve_input_specs(self, feature_names: list[str]) -> list[NormalizedInputSpec]: 

483 resolved: list[NormalizedInputSpec] = [] 

484 for idx, name in enumerate(feature_names): 

485 spec = self._fetch_input_config(name, idx) 

486 resolved.append(self._normalize_input_spec(spec)) 

487 return resolved 

488 

489 # ------------------------------------------------------------------ 

490 # Representation helpers 

491 # ------------------------------------------------------------------ 

492 def _repr_config_pairs(self) -> list[tuple[str, Any]]: 

493 optimizer_label = self._describe_optimizer_config(self.optimizer) 

494 pairs: list[tuple[str, Any]] = [ 

495 ("n_mfs", self.n_mfs), 

496 ("mf_type", self.mf_type), 

497 ("init", self.init), 

498 ("overlap", self.overlap), 

499 ("margin", self.margin), 

500 ("random_state", self.random_state), 

501 ("optimizer", optimizer_label), 

502 ("learning_rate", self.learning_rate), 

503 ("epochs", self.epochs), 

504 ("batch_size", self.batch_size), 

505 ("shuffle", self.shuffle), 

506 ("loss", self.loss), 

507 ] 

508 if self.rules is not None: 

509 pairs.append(("rules", f"preset:{len(self.rules)}")) 

510 if self.optimizer_params: 

511 pairs.append(("optimizer_params", self.optimizer_params)) 

512 return pairs 

513 

514 def _repr_children_entries(self) -> list[tuple[str, str]]: 

515 if not getattr(self, "is_fitted_", False): 

516 return [] 

517 

518 children: list[tuple[str, str]] = [] 

519 model = getattr(self, "model_", None) 

520 if model is not None: 

521 children.append(("model_", self._summarize_model(model))) 

522 

523 optimizer = getattr(self, "optimizer_", None) 

524 if optimizer is not None: 

525 children.append(("optimizer_", self._summarize_optimizer(optimizer))) 

526 

527 history = getattr(self, "training_history_", None) 

528 if isinstance(history, Mapping) and history: 

529 children.append(("training_history_", self._summarize_history(history))) 

530 

531 learned_rules = getattr(self, "rules_", None) 

532 if learned_rules is not None: 

533 children.append(("rules_", f"{len(learned_rules)} learned")) 

534 

535 feature_names = getattr(self, "feature_names_in_", None) 

536 if feature_names is not None: 

537 children.append(("feature_names_in_", ", ".join(feature_names))) 

538 

539 return children 

540 

541 @staticmethod 

542 def _describe_optimizer_config(optimizer: Any) -> Any: 

543 if optimizer is None: 

544 return None 

545 if isinstance(optimizer, str): 

546 return optimizer 

547 if inspect.isclass(optimizer): 

548 return optimizer.__name__ 

549 if isinstance(optimizer, BaseTrainer): 

550 return type(optimizer).__name__ 

551 return repr(optimizer) 

552 

553 def _summarize_model(self, model: Any) -> str: 

554 name = type(model).__name__ 

555 parts = [name] 

556 n_inputs = getattr(model, "n_inputs", None) 

557 n_rules = getattr(model, "n_rules", None) 

558 if n_inputs is not None: 

559 parts.append(f"n_inputs={n_inputs}") 

560 if n_rules is not None: 

561 parts.append(f"n_rules={n_rules}") 

562 input_names = getattr(model, "input_names", None) 

563 if input_names: 

564 parts.append(f"inputs={list(input_names)}") 

565 mf_map = getattr(model, "membership_functions", None) 

566 if isinstance(mf_map, Mapping) and mf_map: 

567 counts = [len(mf_map[name]) for name in getattr(model, "input_names", mf_map.keys())] 

568 parts.append(f"mfs_per_input={counts}") 

569 return ", ".join(parts) 

570 

571 def _summarize_optimizer(self, optimizer: BaseTrainer) -> str: 

572 name = type(optimizer).__name__ 

573 fields: list[str] = [] 

574 for attr in ("learning_rate", "epochs", "batch_size", "shuffle", "verbose", "loss"): 

575 if hasattr(optimizer, attr): 

576 value = getattr(optimizer, attr) 

577 if value is not None: 

578 fields.append(f"{attr}={value!r}") 

579 if hasattr(optimizer, "__dict__") and not fields: 

580 # Fall back to repr if no recognised fields were populated 

581 return repr(optimizer) 

582 return f"{name}({', '.join(fields)})" if fields else name 

583 

584 @staticmethod 

585 def _summarize_history(history: Mapping[str, Any]) -> str: 

586 segments: list[str] = [] 

587 for key in ("train", "val", "validation", "metrics"): 

588 if key in history and isinstance(history[key], Sequence): 

589 series = history[key] 

590 length = len(series) 

591 if length == 0: 

592 segments.append(f"{key}=0") 

593 else: 

594 tail = series[-1] 

595 if isinstance(tail, (float, np.floating)): 

596 segments.append(f"{key}={length} (last={float(tail):.4f})") 

597 else: 

598 segments.append(f"{key}={length}") 

599 return ", ".join(segments) if segments else "{}" 

600 

601 def _fetch_input_config(self, name: str, index: int) -> InputConfigValue: 

602 if self.inputs_config is None: 

603 return None 

604 spec = self.inputs_config.get(name) 

605 if spec is not None: 

606 return spec 

607 spec = self.inputs_config.get(index) 

608 if spec is not None: 

609 return spec 

610 alt_key = f"x{index + 1}" 

611 return self.inputs_config.get(alt_key) 

612 

613 def _normalize_input_spec(self, spec: InputConfigValue) -> NormalizedInputSpec: 

614 config: NormalizedInputSpec = { 

615 "n_mfs": self.n_mfs, 

616 "mf_type": self.mf_type, 

617 "init": self.init, 

618 "overlap": self.overlap, 

619 "margin": self.margin, 

620 "range": None, 

621 "membership_functions": None, 

622 } 

623 if spec is None: 

624 return config 

625 if isinstance(spec, (list, tuple)) and all(isinstance(mf, MembershipFunction) for mf in spec): 

626 config["membership_functions"] = list(spec) 

627 return config 

628 if isinstance(spec, MembershipFunction): 

629 config["membership_functions"] = [spec] 

630 return config 

631 if isinstance(spec, Mapping): 

632 mapping = dict(spec) 

633 if "mfs" in mapping and "membership_functions" not in mapping: 

634 mapping = {**mapping, "membership_functions": mapping["mfs"]} 

635 for key in ("n_mfs", "mf_type", "init", "overlap", "margin", "range", "membership_functions"): 

636 if key in mapping and (mapping[key] is not None or key == "init"): 

637 config[key] = mapping[key] 

638 return config 

639 if isinstance(spec, str): 

640 config["mf_type"] = spec 

641 return config 

642 if isinstance(spec, int): 

643 config["n_mfs"] = int(spec) 

644 return config 

645 raise TypeError(f"Unsupported input configuration type: {type(spec)!r}") 

646 

647 def _build_model(self, X: np.ndarray, feature_names: list[str]) -> TSKANFIS: 

648 builder = ANFISBuilder() 

649 if self.input_specs_ is None: 

650 raise RuntimeError("Input specifications must be resolved before building the model.") 

651 for idx, name in enumerate(feature_names): 

652 column = X[:, idx] 

653 spec = self.input_specs_[idx] 

654 mf_list = spec.get("membership_functions") 

655 range_override = spec.get("range") 

656 if mf_list is not None: 

657 builder.input_mfs[name] = [cast(MembershipFunction, mf) for mf in mf_list] 

658 if range_override is not None: 

659 range_tuple = tuple(float(v) for v in range_override) 

660 if len(range_tuple) != 2: 

661 raise ValueError("range overrides must contain exactly two values") 

662 builder.input_ranges[name] = (range_tuple[0], range_tuple[1]) 

663 else: 

664 builder.input_ranges[name] = (float(np.min(column)), float(np.max(column))) 

665 continue 

666 if range_override is not None: 

667 range_tuple = tuple(float(v) for v in range_override) 

668 if len(range_tuple) != 2: 

669 raise ValueError("range overrides must contain exactly two values") 

670 rmin, rmax = range_tuple 

671 builder.add_input( 

672 name, 

673 float(rmin), 

674 float(rmax), 

675 int(spec["n_mfs"]), 

676 str(spec["mf_type"]), 

677 overlap=float(spec["overlap"]), 

678 ) 

679 else: 

680 init_strategy = spec.get("init") 

681 init_arg = None if init_strategy is None else str(init_strategy) 

682 builder.add_input_from_data( 

683 name, 

684 column, 

685 n_mfs=int(spec["n_mfs"]), 

686 mf_type=str(spec["mf_type"]), 

687 overlap=float(spec["overlap"]), 

688 margin=float(spec["margin"]), 

689 init=init_arg, 

690 random_state=self.random_state, 

691 ) 

692 builder.set_rules(self.rules) 

693 return builder.build() 

694 

695 def _instantiate_trainer(self) -> BaseTrainer: 

696 optimizer = self.optimizer if self.optimizer is not None else "hybrid" 

697 if isinstance(optimizer, BaseTrainer): 

698 trainer = deepcopy(optimizer) 

699 self._apply_runtime_overrides(trainer) 

700 return trainer 

701 if inspect.isclass(optimizer) and issubclass(optimizer, BaseTrainer): 

702 params = self._collect_trainer_params(optimizer) 

703 return optimizer(**params) 

704 if isinstance(optimizer, str): 

705 key = optimizer.lower() 

706 if key not in TRAINER_REGISTRY: 

707 supported = ", ".join(sorted(TRAINER_REGISTRY.keys())) 

708 raise ValueError(f"Unknown optimizer '{optimizer}'. Supported: {supported}") 

709 trainer_cls = TRAINER_REGISTRY[key] 

710 params = self._collect_trainer_params(trainer_cls) 

711 return trainer_cls(**params) 

712 raise TypeError("optimizer must be a string identifier, BaseTrainer instance, or BaseTrainer subclass") 

713 

714 def _collect_trainer_params(self, trainer_cls: type[BaseTrainer]) -> dict[str, Any]: 

715 params: dict[str, Any] = {} 

716 if self.optimizer_params is not None: 

717 params.update(self.optimizer_params) 

718 

719 overrides = { 

720 "learning_rate": self.learning_rate, 

721 "epochs": self.epochs, 

722 "batch_size": self.batch_size, 

723 "shuffle": self.shuffle, 

724 "verbose": self.verbose, 

725 "loss": self.loss, 

726 } 

727 for key, value in overrides.items(): 

728 if value is not None and key not in params: 

729 params[key] = value 

730 # Ensure boolean defaults propagate when value could be False 

731 if self.shuffle is not None: 

732 params.setdefault("shuffle", self.shuffle) 

733 params.setdefault("verbose", self.verbose) 

734 

735 sig = inspect.signature(trainer_cls) 

736 filtered: dict[str, Any] = {} 

737 for name in sig.parameters: 

738 if name == "self": 

739 continue 

740 if name in params: 

741 filtered[name] = params[name] 

742 return filtered 

743 

744 def _apply_runtime_overrides(self, trainer: BaseTrainer) -> None: 

745 for attr, value in ( 

746 ("learning_rate", self.learning_rate), 

747 ("epochs", self.epochs), 

748 ("batch_size", self.batch_size), 

749 ("shuffle", self.shuffle), 

750 ("verbose", self.verbose), 

751 ("loss", self.loss), 

752 ): 

753 if value is not None and hasattr(trainer, attr): 

754 setattr(trainer, attr, value) 

755 if hasattr(trainer, "verbose") and self.verbose is not None: 

756 trainer.verbose = self.verbose