Coverage for anfis_toolbox / classifier.py: 100%

436 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-05 18:47 -0300

1"""High-level classification estimator facade for ANFIS. 

2 

3``ANFISClassifier`` exposes a scikit-learn style API that bundles membership 

4function management, model construction, and trainer selection so downstream 

5code can focus on providing data and retrieving predictions. 

6""" 

7 

8from __future__ import annotations 

9 

10import inspect 

11import logging 

12import pickle # nosec B403 

13from collections.abc import Mapping, Sequence 

14from copy import deepcopy 

15from pathlib import Path 

16from typing import Any, TypeAlias, cast 

17 

18import numpy as np 

19import numpy.typing as npt 

20 

21from .builders import ANFISBuilder 

22from .estimator_utils import ( 

23 BaseEstimatorLike, 

24 ClassifierMixinLike, 

25 FittedMixin, 

26 check_is_fitted, 

27 ensure_2d_array, 

28 format_estimator_repr, 

29) 

30from .logging_config import enable_training_logs 

31from .losses import LossFunction 

32from .membership import MembershipFunction 

33from .metrics import ANFISMetrics, MetricValue 

34from .model import TrainingHistory, TSKANFISClassifier 

35from .optim import ( 

36 AdamTrainer, 

37 BaseTrainer, 

38 PSOTrainer, 

39 RMSPropTrainer, 

40 SGDTrainer, 

41) 

42from .optim import ( 

43 HybridAdamTrainer as _HybridAdamTrainer, 

44) 

45from .optim import ( 

46 HybridTrainer as _HybridTrainer, 

47) 

48 

# Type aliases describing the per-input configuration values accepted by
# ``ANFISClassifier(inputs_config=...)`` and their normalized dict form.
InputConfigValue: TypeAlias = Mapping[str, Any] | Sequence[Any] | MembershipFunction | str | int | None
NormalizedInputSpec: TypeAlias = dict[str, Any]

# String aliases accepted by ``ANFISClassifier(optimizer=...)``.
TRAINER_REGISTRY: dict[str, type[BaseTrainer]] = {
    "sgd": SGDTrainer,
    "adam": AdamTrainer,
    "rmsprop": RMSPropTrainer,
    "pso": PSOTrainer,
}

# Hybrid trainers depend on least-squares refinements that only make sense for
# regression; the classifier rejects them explicitly in ``_instantiate_trainer``.
_UNSUPPORTED_TRAINERS: tuple[type[BaseTrainer], ...] = (_HybridTrainer, _HybridAdamTrainer)

60 

61 

62def _ensure_training_logging(verbose: bool) -> None: 

63 if not verbose: 

64 return 

65 logger = logging.getLogger("anfis_toolbox") 

66 if logger.handlers: 

67 return 

68 enable_training_logs() 

69 

70 

class ANFISClassifier(BaseEstimatorLike, FittedMixin, ClassifierMixinLike):
    """Adaptive Neuro-Fuzzy classifier with a scikit-learn style API.

    The estimator manages membership-function synthesis, rule construction, and
    trainer selection so you can focus on calling :meth:`fit`,
    :meth:`predict`, :meth:`predict_proba`, and :meth:`evaluate` with familiar
    NumPy-like data structures.

    Examples:
    --------
    >>> clf = ANFISClassifier()
    >>> clf.fit(X, y)
    ANFISClassifier(...)
    >>> clf.predict([[0.1, -0.2]])
    array([...])

    Parameters
    ----------
    n_classes : int, optional
        Number of target classes. Must be >= 2 when provided. If omitted, the
        classifier infers the class count during the first call to ``fit``.
    n_mfs : int, default=3
        Default number of membership functions per input.
    mf_type : str, default="gaussian"
        Default membership function family applied when membership functions are
        inferred from data.
    init : {"grid", "fcm", "random", None}, default="grid"
        Strategy used when inferring membership functions from data. ``None``
        falls back to ``"grid"``.
    overlap : float, default=0.5
        Controls overlap when generating membership functions automatically.
    margin : float, default=0.10
        Margin added around observed data ranges during grid initialization.
    inputs_config : Mapping, optional
        Per-input overrides. Keys may be feature names (when ``X`` is a
        :class:`pandas.DataFrame`) or integer indices. Values may be:

        * ``dict`` with keys among ``{"n_mfs", "mf_type", "init", "overlap",
          "margin", "range", "membership_functions", "mfs"}``.
        * A list or tuple of membership function objects for full control.
        * ``None`` for defaults.
    random_state : int, optional
        Random state forwarded to initialization routines and stochastic
        optimizers.
    optimizer : str, BaseTrainer, type[BaseTrainer], or None, default="adam"
        Trainer identifier or instance used for fitting. Strings map to entries
        in :data:`TRAINER_REGISTRY`. ``None`` defaults to "adam". Hybrid
        variants that depend on least-squares refinements are limited to
        regression and raise ``ValueError`` when supplied here.
    optimizer_params : Mapping, optional
        Additional keyword arguments forwarded to the trainer constructor.
    learning_rate, epochs, batch_size, shuffle, verbose : optional scalars
        Common trainer hyper-parameters provided for convenience. When the
        selected trainer supports the parameter it is included automatically.
        ``shuffle`` supports ``False`` to disable random shuffling.
    loss : str or LossFunction, optional
        Custom loss forwarded to trainers that expose a ``loss`` parameter.
        ``None`` resolves to cross-entropy.
    rules : Sequence[Sequence[int]] | None, optional
        Explicit fuzzy rule indices to use instead of the full Cartesian product. Each
        rule lists the membership-function index per input. ``None`` keeps the default
        exhaustive rule set.
    """

    def __init__(
        self,
        *,
        n_classes: int | None = None,
        n_mfs: int = 3,
        mf_type: str = "gaussian",
        init: str | None = "grid",
        overlap: float = 0.5,
        margin: float = 0.10,
        inputs_config: Mapping[Any, Any] | None = None,
        random_state: int | None = None,
        optimizer: str | BaseTrainer | type[BaseTrainer] | None = "adam",
        optimizer_params: Mapping[str, Any] | None = None,
        learning_rate: float | None = None,
        epochs: int | None = None,
        batch_size: int | None = None,
        shuffle: bool | None = None,
        verbose: bool = False,
        loss: LossFunction | str | None = None,
        rules: Sequence[Sequence[int]] | None = None,
    ) -> None:
        """Configure an :class:`ANFISClassifier`; see the class docstring for parameter details.

        Raises:
        ------
        ValueError
            If ``n_classes`` is provided and smaller than two.
        """
        if n_classes is not None and int(n_classes) < 2:
            raise ValueError("n_classes must be >= 2")
        self.n_classes: int | None = int(n_classes) if n_classes is not None else None
        self.n_mfs = int(n_mfs)
        self.mf_type = str(mf_type)
        self.init = None if init is None else str(init)
        self.overlap = float(overlap)
        self.margin = float(margin)
        # Defensive copies keep later mutation of the caller's mappings from
        # silently changing estimator configuration.
        self.inputs_config: dict[Any, InputConfigValue] | None = (
            dict(inputs_config) if inputs_config is not None else None
        )
        self.random_state = random_state
        self.optimizer = optimizer
        self.optimizer_params = dict(optimizer_params) if optimizer_params is not None else None
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.verbose = verbose
        self.loss = loss
        self.rules = None if rules is None else tuple(tuple(int(idx) for idx in rule) for rule in rules)

        # Fitted attributes (initialised during fit)
        self.model_: TSKANFISClassifier | None = None
        self.optimizer_: BaseTrainer | None = None
        self.feature_names_in_: list[str] | None = None
        self.n_features_in_: int | None = None
        self.training_history_: TrainingHistory | None = None
        self.input_specs_: list[NormalizedInputSpec] | None = None
        self.classes_: np.ndarray | None = None
        self._class_to_index_: dict[Any, int] | None = None
        self.rules_: list[tuple[int, ...]] | None = None

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def fit(
        self,
        X: npt.ArrayLike,
        y: npt.ArrayLike,
        *,
        validation_data: tuple[np.ndarray, np.ndarray] | None = None,
        validation_frequency: int = 1,
        verbose: bool | None = None,
        **fit_params: Any,
    ) -> ANFISClassifier:
        """Fit the classifier on labelled data.

        Parameters
        ----------
        X : array-like
            Training inputs with shape ``(n_samples, n_features)``.
        y : array-like
            Target labels. Accepts integer or string labels as well as one-hot
            matrices with shape ``(n_samples, n_classes)``.
        validation_data : tuple[np.ndarray, np.ndarray], optional
            Optional validation split supplied to the underlying trainer.
            Inputs and targets must already be numeric and share the same row
            count.
        validation_frequency : int, default=1
            Frequency (in epochs) at which validation metrics are computed when
            ``validation_data`` is provided.
        verbose : bool, optional
            Override the estimator's ``verbose`` flag for this fit call. When
            provided, the value is stored on the estimator and forwarded to the
            trainer configuration.
        **fit_params : Any
            Additional keyword arguments forwarded directly to the trainer
            ``fit`` method.

        Returns:
        -------
        ANFISClassifier
            Reference to ``self`` to enable fluent-style chaining.

        Raises:
        ------
        ValueError
            If the input arrays disagree on the number of samples or the label
            encoding is incompatible with the configured ``n_classes``.
        TypeError
            If the trainer ``fit`` implementation does not return a
            dictionary-style training history.
        """
        X_arr, feature_names = ensure_2d_array(X)
        n_samples = X_arr.shape[0]
        # Encode first: this may infer self.n_classes and yields the class set
        # that fixes the probability-column order for the lifetime of the fit.
        y_encoded, classes = self._encode_targets(y, n_samples)

        self.classes_ = classes
        self._class_to_index_ = {self._normalize_class_key(cls): idx for idx, cls in enumerate(classes.tolist())}

        self.feature_names_in_ = feature_names
        self.n_features_in_ = X_arr.shape[1]
        self.input_specs_ = self._resolve_input_specs(feature_names)

        if verbose is not None:
            self.verbose = bool(verbose)

        _ensure_training_logging(self.verbose)
        if self.n_classes is None:
            raise RuntimeError("n_classes could not be inferred from the provided targets")
        self.model_ = self._build_model(X_arr, feature_names)
        trainer = self._instantiate_trainer()
        self.optimizer_ = trainer
        trainer_kwargs: dict[str, Any] = dict(fit_params)
        if validation_data is not None:
            trainer_kwargs.setdefault("validation_data", validation_data)
        # Only forward validation_frequency when it matters, so trainers that do
        # not accept the keyword are not broken by the default case.
        if validation_data is not None or validation_frequency != 1:
            trainer_kwargs.setdefault("validation_frequency", validation_frequency)

        history = trainer.fit(self.model_, X_arr, y_encoded, **trainer_kwargs)
        if not isinstance(history, dict):
            raise TypeError("Trainer.fit must return a TrainingHistory dictionary")
        self.training_history_ = history
        self.rules_ = self.model_.rules

        self._mark_fitted()
        return self

    def predict(self, X: npt.ArrayLike) -> np.ndarray:
        """Predict class labels for the provided samples.

        Parameters
        ----------
        X : array-like
            Samples to classify. One-dimensional arrays are treated as a single
            sample; two-dimensional arrays must have shape ``(n_samples, n_features)``.

        Returns:
        -------
        np.ndarray
            Predicted class labels with shape ``(n_samples,)``.

        Raises:
        ------
        RuntimeError
            If invoked before the estimator is fitted.
        ValueError
            When the supplied samples do not match the fitted feature count.
        """
        check_is_fitted(self, attributes=["model_", "classes_"])
        X_arr = np.asarray(X, dtype=float)
        if X_arr.ndim == 1:
            # A single flat sample becomes one row.
            X_arr = X_arr.reshape(1, -1)
        else:
            # Re-run validation on the raw input so DataFrames keep their
            # column handling.
            X_arr, _ = ensure_2d_array(X)

        if self.n_features_in_ is None:
            raise RuntimeError("Model must be fitted before calling predict.")
        if X_arr.shape[1] != self.n_features_in_:
            raise ValueError(f"Feature mismatch: expected {self.n_features_in_}, got {X_arr.shape[1]}.")
        model = self.model_
        classes = self.classes_
        if model is None or classes is None:
            raise RuntimeError("Model must be fitted before calling predict.")
        # The low-level model emits class indices; map back to original labels.
        encoded = np.asarray(model.predict(X_arr), dtype=int)
        return cast(np.ndarray, np.asarray(classes)[encoded])

    def predict_proba(self, X: npt.ArrayLike) -> np.ndarray:
        """Predict class probabilities for the provided samples.

        Parameters
        ----------
        X : array-like
            Samples for which to estimate class probabilities.

        Returns:
        -------
        np.ndarray
            Matrix of shape ``(n_samples, n_classes)`` containing class
            probability estimates.

        Raises:
        ------
        RuntimeError
            If the estimator has not been fitted.
        ValueError
            If sample dimensionality does not match the fitted feature count.
        """
        check_is_fitted(self, attributes=["model_"])
        X_arr = np.asarray(X, dtype=float)
        if X_arr.ndim == 1:
            X_arr = X_arr.reshape(1, -1)
        else:
            X_arr, _ = ensure_2d_array(X)

        if self.n_features_in_ is None:
            raise RuntimeError("Model must be fitted before calling predict_proba.")
        if X_arr.shape[1] != self.n_features_in_:
            raise ValueError(f"Feature mismatch: expected {self.n_features_in_}, got {X_arr.shape[1]}.")
        model = self.model_
        if model is None:
            raise RuntimeError("Model must be fitted before calling predict_proba.")
        return np.asarray(model.predict_proba(X_arr), dtype=float)

    def evaluate(
        self,
        X: npt.ArrayLike,
        y: npt.ArrayLike,
        *,
        return_dict: bool = True,
        print_results: bool = True,
    ) -> Mapping[str, MetricValue] | None:
        """Evaluate predictive performance on a labelled dataset.

        Parameters
        ----------
        X : array-like
            Evaluation inputs.
        y : array-like
            Ground-truth labels. Accepts integer labels or one-hot encodings.
            Labels must come from the class set observed at fit time; a subset
            is allowed (e.g. an evaluation split missing a rare class).
        return_dict : bool, default=True
            When ``True`` return the computed metric dictionary; when ``False``
            return ``None`` after optional printing.
        print_results : bool, default=True
            Emit a formatted summary to stdout. Set to ``False`` to suppress
            printing.

        Returns:
        -------
        Mapping[str, MetricValue] | None
            Dictionary containing accuracy, balanced accuracy, macro/micro
            precision/recall/F1 scores, and the confusion matrix when
            ``return_dict`` is ``True``; otherwise ``None``.

        Raises:
        ------
        RuntimeError
            If called before the estimator has been fitted.
        ValueError
            When ``X`` and ``y`` disagree on sample count or labels are
            incompatible with the fitted class set.
        """
        check_is_fitted(self, attributes=["model_"])
        X_arr, _ = ensure_2d_array(X)
        # allow_partial_classes: the evaluation split need not contain every
        # fitted class; encoding uses the fit-time mapping (see _encode_targets).
        encoded_targets, _ = self._encode_targets(y, X_arr.shape[0], allow_partial_classes=True)
        proba = self.predict_proba(X_arr)
        metrics: dict[str, MetricValue] = ANFISMetrics.classification_metrics(encoded_targets, proba)
        # log-loss is dropped from the report: it is undefined for hard labels
        # absent from the evaluation split.
        metrics.pop("log_loss", None)
        if print_results:

            def _is_effectively_nan(value: Any) -> bool:
                # Treat None, NaN floats, and all-NaN numeric arrays as
                # "nothing to report"; everything else is printable.
                if value is None:
                    return True
                if isinstance(value, (float, np.floating)):
                    return bool(np.isnan(value))
                if isinstance(value, (int, np.integer)):
                    return False
                if isinstance(value, np.ndarray):
                    if value.size == 0:
                        return False
                    if np.issubdtype(value.dtype, np.number):
                        return bool(np.isnan(value.astype(float)).all())
                    return False
                return False

            print("ANFISClassifier evaluation:")  # noqa: T201
            for key, value in metrics.items():
                if _is_effectively_nan(value):
                    continue
                if isinstance(value, (float, np.floating)):
                    display_value = f"{float(value):.6f}"
                    print(f"  {key}: {display_value}")  # noqa: T201
                elif isinstance(value, (int, np.integer)):
                    print(f"  {key}: {int(value)}")  # noqa: T201
                elif isinstance(value, np.ndarray):
                    array_repr = np.array2string(value, precision=6, suppress_small=True)
                    if "\n" in array_repr:
                        indented = "\n    ".join(array_repr.splitlines())
                        print(f"  {key}:\n    {indented}")  # noqa: T201
                    else:
                        print(f"  {key}: {array_repr}")  # noqa: T201
                else:
                    print(f"  {key}: {value}")  # noqa: T201
        return metrics if return_dict else None

    def get_rules(self) -> tuple[tuple[int, ...], ...]:
        """Return the fuzzy rule index combinations used by the fitted model.

        Returns:
        -------
        tuple[tuple[int, ...], ...]
            Immutable tuple describing each fuzzy rule as a per-input
            membership index.

        Raises:
        ------
        RuntimeError
            If invoked before ``fit`` completes.
        """
        check_is_fitted(self, attributes=["rules_"])
        if not self.rules_:
            return ()
        return tuple(tuple(rule) for rule in self.rules_)

    def save(self, filepath: str | Path) -> None:
        """Serialize this estimator (including fitted artefacts) to ``filepath``."""
        path = Path(filepath)
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("wb") as stream:
            pickle.dump(self, stream)  # nosec B301

    @classmethod
    def load(cls, filepath: str | Path) -> ANFISClassifier:
        """Load a pickled ``ANFISClassifier`` from ``filepath`` and validate its type.

        NOTE: pickle deserialization executes arbitrary code; only load files
        from trusted sources.
        """
        path = Path(filepath)
        with path.open("rb") as stream:
            estimator = pickle.load(stream)  # nosec B301
        if not isinstance(estimator, cls):
            raise TypeError(f"Expected pickled {cls.__name__} instance, got {type(estimator).__name__}.")
        return estimator

    def __repr__(self) -> str:
        """Return a formatted representation summarising configuration and fitted artefacts."""
        return format_estimator_repr(
            type(self).__name__,
            self._repr_config_pairs(),
            self._repr_children_entries(),
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _resolve_input_specs(self, feature_names: list[str]) -> list[NormalizedInputSpec]:
        """Normalize the user-supplied per-input configuration for every feature."""
        resolved: list[NormalizedInputSpec] = []
        for idx, name in enumerate(feature_names):
            spec = self._fetch_input_config(name, idx)
            resolved.append(self._normalize_input_spec(spec))
        return resolved

    def _fetch_input_config(self, name: str, index: int) -> InputConfigValue:
        """Look up a feature's override by name, positional index, or ``x{i}`` alias."""
        if self.inputs_config is None:
            return None
        spec = self.inputs_config.get(name)
        if spec is not None:
            return spec
        spec = self.inputs_config.get(index)
        if spec is not None:
            return spec
        # Legacy alias: 1-based "x1", "x2", ... keys.
        alt_key = f"x{index + 1}"
        return self.inputs_config.get(alt_key)

    def _normalize_input_spec(self, spec: InputConfigValue) -> NormalizedInputSpec:
        """Expand a single override value into the full normalized spec dict.

        Raises:
        ------
        TypeError
            If the value is not a mapping, membership-function sequence,
            membership function, string, integer, or ``None``.
        """
        config: NormalizedInputSpec = {
            "n_mfs": self.n_mfs,
            "mf_type": self.mf_type,
            "init": self.init,
            "overlap": self.overlap,
            "margin": self.margin,
            "range": None,
            "membership_functions": None,
        }
        if spec is None:
            return config
        if isinstance(spec, (list, tuple)) and all(isinstance(mf, MembershipFunction) for mf in spec):
            config["membership_functions"] = list(spec)
            return config
        if isinstance(spec, MembershipFunction):
            config["membership_functions"] = [spec]
            return config
        if isinstance(spec, Mapping):
            mapping = dict(spec)
            # "mfs" is accepted as shorthand for "membership_functions".
            if "mfs" in mapping and "membership_functions" not in mapping:
                mapping = {**mapping, "membership_functions": mapping["mfs"]}
            for key in ("n_mfs", "mf_type", "init", "overlap", "margin", "range", "membership_functions"):
                # "init" may be explicitly overridden to None; other keys only
                # override when a non-None value is supplied.
                if key in mapping and (mapping[key] is not None or key == "init"):
                    config[key] = mapping[key]
            return config
        if isinstance(spec, str):
            config["mf_type"] = spec
            return config
        if isinstance(spec, int):
            config["n_mfs"] = int(spec)
            return config
        raise TypeError(f"Unsupported input configuration type: {type(spec)!r}")

    def _build_model(self, X: np.ndarray, feature_names: list[str]) -> TSKANFISClassifier:
        """Construct the low-level TSK classifier from data and normalized specs.

        Raises:
        ------
        RuntimeError
            If input specs or the class count are not yet resolved.
        ValueError
            If a ``range`` override does not contain exactly two values.
        """
        builder = ANFISBuilder()
        if self.input_specs_ is None:
            raise RuntimeError("Input specifications must be resolved before building the model.")
        if self.n_classes is None:
            raise RuntimeError("Number of classes must be known before constructing the low-level model.")

        def _as_range(values: Any) -> tuple[float, float]:
            # Shared validation for "range" overrides.
            rng = tuple(float(v) for v in values)
            if len(rng) != 2:
                raise ValueError("range overrides must contain exactly two values")
            return rng[0], rng[1]

        for idx, name in enumerate(feature_names):
            column = X[:, idx]
            spec = self.input_specs_[idx]
            mf_list = spec.get("membership_functions")
            range_override = spec.get("range")
            if mf_list is not None:
                # Explicit membership functions: register them directly and only
                # record the input range (override or observed data extent).
                builder.input_mfs[name] = [cast(MembershipFunction, mf) for mf in mf_list]
                if range_override is not None:
                    builder.input_ranges[name] = _as_range(range_override)
                else:
                    builder.input_ranges[name] = (float(np.min(column)), float(np.max(column)))
                continue
            if range_override is not None:
                rmin, rmax = _as_range(range_override)
                builder.add_input(
                    name,
                    float(rmin),
                    float(rmax),
                    int(spec["n_mfs"]),
                    str(spec["mf_type"]),
                    overlap=float(spec["overlap"]),
                )
            else:
                init_strategy = spec.get("init")
                init_arg = None if init_strategy is None else str(init_strategy)
                builder.add_input_from_data(
                    name,
                    column,
                    n_mfs=int(spec["n_mfs"]),
                    mf_type=str(spec["mf_type"]),
                    overlap=float(spec["overlap"]),
                    margin=float(spec["margin"]),
                    init=init_arg,
                    random_state=self.random_state,
                )
        return TSKANFISClassifier(
            builder.input_mfs,
            n_classes=self.n_classes,
            random_state=self.random_state,
            rules=self.rules,
        )

    # ------------------------------------------------------------------
    # Representation helpers
    # ------------------------------------------------------------------
    def _repr_config_pairs(self) -> list[tuple[str, Any]]:
        """Collect (name, value) configuration pairs for ``__repr__``."""
        optimizer_label = self._describe_optimizer_config(self.optimizer)
        pairs: list[tuple[str, Any]] = [
            ("n_classes", self.n_classes),
            ("n_mfs", self.n_mfs),
            ("mf_type", self.mf_type),
            ("init", self.init),
            ("overlap", self.overlap),
            ("margin", self.margin),
            ("random_state", self.random_state),
            ("optimizer", optimizer_label),
            ("learning_rate", self.learning_rate),
            ("epochs", self.epochs),
            ("batch_size", self.batch_size),
            ("shuffle", self.shuffle),
            ("loss", self.loss),
        ]
        if self.rules is not None:
            pairs.append(("rules", f"preset:{len(self.rules)}"))
        if self.optimizer_params:
            pairs.append(("optimizer_params", self.optimizer_params))
        return pairs

    def _repr_children_entries(self) -> list[tuple[str, str]]:
        """Summarize fitted artefacts for ``__repr__``; empty before fitting."""
        if not getattr(self, "is_fitted_", False):
            return []

        children: list[tuple[str, str]] = []
        model = getattr(self, "model_", None)
        if model is not None:
            children.append(("model_", self._summarize_model(model)))

        optimizer = getattr(self, "optimizer_", None)
        if optimizer is not None:
            children.append(("optimizer_", self._summarize_optimizer(optimizer)))

        history = getattr(self, "training_history_", None)
        if isinstance(history, Mapping) and history:
            children.append(("training_history_", self._summarize_history(history)))

        class_labels = getattr(self, "classes_", None)
        if class_labels is not None:
            labels = list(map(str, class_labels))
            # Truncate long class lists to keep the repr readable.
            preview = labels if len(labels) <= 6 else labels[:5] + ["..."]
            children.append(("classes_", ", ".join(preview)))

        learned_rules = getattr(self, "rules_", None)
        if learned_rules is not None:
            children.append(("rules_", f"{len(learned_rules)} learned"))

        feature_names = getattr(self, "feature_names_in_", None)
        if feature_names is not None:
            children.append(("feature_names_in_", ", ".join(feature_names)))

        return children

    @staticmethod
    def _describe_optimizer_config(optimizer: Any) -> Any:
        """Produce a short label describing the configured optimizer."""
        if optimizer is None:
            return None
        if isinstance(optimizer, str):
            return optimizer
        if inspect.isclass(optimizer):
            return optimizer.__name__
        if isinstance(optimizer, BaseTrainer):
            return type(optimizer).__name__
        return repr(optimizer)

    def _summarize_model(self, model: Any) -> str:
        """Build a one-line summary of the fitted low-level model."""
        name = type(model).__name__
        parts = [name]
        n_inputs = getattr(model, "n_inputs", None)
        n_rules = getattr(model, "n_rules", None)
        n_classes = getattr(model, "n_classes", None)
        if n_inputs is not None:
            parts.append(f"n_inputs={n_inputs}")
        if n_rules is not None:
            parts.append(f"n_rules={n_rules}")
        if n_classes is not None:
            parts.append(f"n_classes={n_classes}")
        input_names = getattr(model, "input_names", None)
        if input_names:
            parts.append(f"inputs={list(input_names)}")
        mf_map = getattr(model, "membership_functions", None)
        if isinstance(mf_map, Mapping) and mf_map:
            counts = [len(mf_map[name]) for name in getattr(model, "input_names", mf_map.keys())]
            parts.append(f"mfs_per_input={counts}")
        return ", ".join(parts)

    def _summarize_optimizer(self, optimizer: BaseTrainer) -> str:
        """Build a one-line summary of the trainer's common hyper-parameters."""
        name = type(optimizer).__name__
        fields: list[str] = []
        for attr in ("learning_rate", "epochs", "batch_size", "shuffle", "verbose", "loss"):
            if hasattr(optimizer, attr):
                value = getattr(optimizer, attr)
                if value is not None:
                    fields.append(f"{attr}={value!r}")
        if hasattr(optimizer, "__dict__") and not fields:
            # Nothing recognisable to report: fall back to the trainer's repr.
            return repr(optimizer)
        return f"{name}({', '.join(fields)})" if fields else name

    @staticmethod
    def _summarize_history(history: Mapping[str, Any]) -> str:
        """Summarize known history series by length and last value."""
        segments: list[str] = []
        for key in ("train", "val", "validation", "metrics"):
            if key in history and isinstance(history[key], Sequence):
                series = history[key]
                length = len(series)
                if length == 0:
                    segments.append(f"{key}=0")
                else:
                    tail = series[-1]
                    if isinstance(tail, (float, np.floating)):
                        segments.append(f"{key}={length} (last={float(tail):.4f})")
                    else:
                        segments.append(f"{key}={length}")
        return ", ".join(segments) if segments else "{}"

    def _instantiate_trainer(self) -> BaseTrainer:
        """Resolve ``self.optimizer`` into a ready-to-use trainer instance.

        Raises:
        ------
        ValueError
            For unknown string aliases or regression-only hybrid trainers.
        TypeError
            When ``optimizer`` is not a string, trainer instance, or class.
        """
        optimizer = self.optimizer if self.optimizer is not None else "adam"
        if isinstance(optimizer, BaseTrainer):
            if isinstance(optimizer, _UNSUPPORTED_TRAINERS):
                raise ValueError(
                    "Hybrid-style trainers that rely on least-squares updates are not supported by ANFISClassifier. "
                    "Choose among: "
                    f"{', '.join(sorted(TRAINER_REGISTRY.keys()))}."
                )
            # Deep-copy so repeated fits never share mutable trainer state.
            trainer = deepcopy(optimizer)
            self._apply_runtime_overrides(trainer)
            return trainer
        if inspect.isclass(optimizer) and issubclass(optimizer, BaseTrainer):
            if issubclass(optimizer, _UNSUPPORTED_TRAINERS):
                raise ValueError(
                    "Hybrid-style trainers that rely on least-squares updates are not supported by ANFISClassifier. "
                    "Choose among: "
                    f"{', '.join(sorted(TRAINER_REGISTRY.keys()))}."
                )
            params = self._collect_trainer_params(optimizer)
            return optimizer(**params)
        if isinstance(optimizer, str):
            key = optimizer.lower()
            if key in {"hybrid", "hybrid_adam"}:
                raise ValueError(
                    "Hybrid-style optimizers that combine least-squares with gradient descent are only available "
                    "for regression. Supported classifier optimizers: "
                    f"{', '.join(sorted(TRAINER_REGISTRY.keys()))}."
                )
            if key not in TRAINER_REGISTRY:
                supported = ", ".join(sorted(TRAINER_REGISTRY.keys()))
                raise ValueError(f"Unknown optimizer '{optimizer}'. Supported: {supported}")
            trainer_cls = TRAINER_REGISTRY[key]
            params = self._collect_trainer_params(trainer_cls)
            return trainer_cls(**params)
        raise TypeError("optimizer must be a string identifier, BaseTrainer instance, or BaseTrainer subclass")

    def _collect_trainer_params(self, trainer_cls: type[BaseTrainer]) -> dict[str, Any]:
        """Merge optimizer_params with convenience overrides, filtered by signature.

        Explicit ``optimizer_params`` entries win over the convenience
        attributes; only parameters accepted by ``trainer_cls`` survive.
        """
        params: dict[str, Any] = {}
        if self.optimizer_params is not None:
            params.update(self.optimizer_params)

        overrides: dict[str, Any] = {
            "learning_rate": self.learning_rate,
            "epochs": self.epochs,
            "batch_size": self.batch_size,
            "shuffle": self.shuffle,
            "verbose": self.verbose,
            "loss": self._resolved_loss_spec(),
        }
        # "verbose" (bool) and "loss" (defaulted) are never None, so they are
        # always injected here unless already present in optimizer_params.
        for key, value in overrides.items():
            if value is not None and key not in params:
                params[key] = value

        sig = inspect.signature(trainer_cls)
        filtered: dict[str, Any] = {}
        for name in sig.parameters:
            if name == "self":
                continue
            if name in params:
                filtered[name] = params[name]
        return filtered

    def _apply_runtime_overrides(self, trainer: BaseTrainer) -> None:
        """Copy convenience hyper-parameters onto a user-supplied trainer instance."""
        resolved_loss = self._resolved_loss_spec()
        for attr, value in (
            ("learning_rate", self.learning_rate),
            ("epochs", self.epochs),
            ("batch_size", self.batch_size),
            ("shuffle", self.shuffle),
            ("verbose", self.verbose),
        ):
            if value is not None and hasattr(trainer, attr):
                setattr(trainer, attr, value)
        if hasattr(trainer, "loss") and resolved_loss is not None:
            trainer.loss = resolved_loss

    def _encode_targets(
        self,
        y: npt.ArrayLike,
        n_samples: int,
        *,
        allow_partial_classes: bool = False,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Encode targets as integer class indices.

        Parameters
        ----------
        y : array-like
            1-D label array or 2-D one-hot matrix.
        n_samples : int
            Expected number of rows (must match ``y``).
        allow_partial_classes : bool, default=False
            When ``True`` (evaluation path), ``y`` may contain only a subset of
            the fitted classes and is encoded with the fit-time mapping.

        Returns:
        -------
        tuple[np.ndarray, np.ndarray]
            ``(encoded, classes)`` where ``encoded`` holds integer indices and
            ``classes`` the ordered class labels those indices refer to.

        Raises:
        ------
        ValueError
            On sample-count mismatch, too few classes, class-count mismatch
            with the configured ``n_classes``, labels unseen at fit time, or
            targets that are neither 1-D labels nor a 2-D one-hot matrix.
        """
        y_arr: np.ndarray = np.asarray(y)
        if y_arr.ndim == 2:
            if y_arr.shape[0] != n_samples:
                raise ValueError("y must contain the same number of samples as X")
            n_classes = self.n_classes
            if n_classes is None:
                inferred = y_arr.shape[1]
                if inferred < 2:
                    raise ValueError("One-hot targets must encode at least two classes for classification.")
                self.n_classes = inferred
                n_classes = inferred
            if y_arr.shape[1] != n_classes:
                raise ValueError(f"One-hot targets must have shape (n_samples, n_classes={n_classes}).")
            encoded = np.argmax(y_arr, axis=1).astype(int)
            classes = np.arange(n_classes)
            return encoded, classes
        if y_arr.ndim == 1:
            if y_arr.shape[0] != n_samples:
                raise ValueError("y must contain the same number of samples as X")
            # BUG FIX: when evaluating a fitted estimator, encode with the
            # mapping learned at fit time. Previously the mapping was re-derived
            # from np.unique(y); if the evaluation labels were a strict subset
            # of the fitted classes, the indices disagreed with
            # predict_proba's column order and every metric was corrupted.
            if allow_partial_classes and self._class_to_index_ is not None and self.classes_ is not None:
                fitted_mapping = self._class_to_index_
                try:
                    encoded = np.array(
                        [fitted_mapping[self._normalize_class_key(val)] for val in y_arr],
                        dtype=int,
                    )
                except KeyError as err:
                    raise ValueError(f"y contains label {err.args[0]!r} not seen during fit.") from err
                return encoded, np.asarray(self.classes_)
            classes = np.unique(y_arr)
            n_unique = classes.size
            if n_unique < 2 and not allow_partial_classes:
                raise ValueError("Classification targets must include at least two distinct classes.")
            n_classes = self.n_classes
            if n_classes is None:
                if n_unique < 2:
                    raise ValueError("Classification targets must include at least two distinct classes.")
                self.n_classes = n_unique
                n_classes = n_unique
            if not allow_partial_classes and n_unique != n_classes:
                raise ValueError(f"y contains {n_unique} unique classes but estimator was configured for {n_classes}.")
            if n_unique > n_classes:
                raise ValueError(
                    f"y contains {n_unique} unique classes which exceeds configured n_classes={n_classes}."
                )
            normalized_classes = [self._normalize_class_key(cls) for cls in classes.tolist()]
            mapping = {cls: idx for idx, cls in enumerate(normalized_classes)}
            encoded = np.array([mapping[self._normalize_class_key(val)] for val in y_arr], dtype=int)
            return encoded, np.asarray(normalized_classes)
        raise ValueError("Target array must be 1-dimensional or a one-hot encoded 2D array.")

    def _resolved_loss_spec(self) -> LossFunction | str:
        """Return the configured loss, defaulting to cross-entropy."""
        if self.loss is None:
            return "cross_entropy"
        return self.loss

    def _more_tags(self) -> dict[str, Any]:  # pragma: no cover - informational hook
        """Expose scikit-learn style estimator tags."""
        return {
            "estimator_type": "classifier",
            "requires_y": True,
        }

    @staticmethod
    def _normalize_class_key(value: Any) -> Any:
        """Convert NumPy scalars to native Python so dict lookups are stable."""
        return value.item() if isinstance(value, np.generic) else value