Coverage for anfis_toolbox / estimator_utils.py: 100%

161 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-05 18:47 -0300

1"""Lightweight utilities for scikit-learn style estimators without external dependencies. 

2 

3This module provides a minimal subset of the scikit-learn estimator contract so that 

4high-level ANFIS interfaces can expose familiar methods (`fit`, `predict`, 

5`get_params`, `set_params`, etc.) without requiring scikit-learn as a runtime 

6dependency. The helpers here intentionally implement only the pieces we need 

7and keep them Numpy-centric for portability. 

8""" 

9 

10from __future__ import annotations 

11 

12from collections.abc import Iterable 

13from copy import deepcopy 

14from dataclasses import dataclass 

15from typing import Any 

16 

17import numpy as np 

18 

19try: # pragma: no cover - optional dependency 

20 from sklearn.utils._tags import Tags, TargetTags 

21 

22 _SKLEARN_TAGS_AVAILABLE = True 

23except Exception: # pragma: no cover - sklearn not installed 

24 Tags = None 

25 TargetTags = None 

26 _SKLEARN_TAGS_AVAILABLE = False 

27 

# Public API of this module: names exported by `from ... import *` and the
# surface documented for downstream ANFIS estimator classes.
__all__ = [
    "BaseEstimatorLike",
    "RegressorMixinLike",
    "ClassifierMixinLike",
    "FittedMixin",
    "RuleInspectorMixin",
    "NotFittedError",
    "check_is_fitted",
    "ensure_2d_array",
    "ensure_vector",
    "infer_feature_names",
    "format_estimator_repr",
]

41 

42 

class NotFittedError(RuntimeError):
    """Signal that a fit-dependent operation was invoked on an unfitted estimator."""

45 

46 

47class BaseEstimatorLike: 

48 """Mixin implementing scikit-learn style parameter inspection. 

49 

50 Parameters are assumed to live on the instance `__dict__` and be declared in 

51 `__init__`. This matches the common sklearn design pattern and enables 

52 cloning/grid-search like workflows without relying on sklearn itself. 

53 """ 

54 

55 def get_params(self, deep: bool = True) -> dict[str, Any]: 

56 """Return estimator parameters following sklearn conventions.""" 

57 

58 def clone_param(value: Any) -> Any: 

59 if isinstance(value, dict): 

60 return {k: clone_param(v) for k, v in value.items()} 

61 if isinstance(value, (list, tuple)): 

62 return type(value)(clone_param(v) for v in value) 

63 # Primitive / numpy scalars 

64 if isinstance(value, (str, int, float, bool, type(None), np.generic)): 

65 return value 

66 # Fallback to deepcopy for custom objects 

67 return deepcopy(value) 

68 

69 return {key: clone_param(value) for key, value in self.__dict__.items() if not key.endswith("_")} 

70 

71 def set_params(self, **params: Any) -> BaseEstimatorLike: 

72 """Set estimator parameters and return self.""" 

73 for key, value in params.items(): 

74 if not hasattr(self, key): 

75 raise ValueError(f"Invalid parameter '{key}' for {type(self).__name__}.") 

76 setattr(self, key, value) 

77 return self 

78 

79 # ------------------------------------------------------------------ 

80 # scikit-learn compatibility hooks 

81 # ------------------------------------------------------------------ 

82 def __sklearn_tags__(self) -> dict[str, Any] | Any: 

83 """Return estimator capability tags expected by scikit-learn.""" 

84 merged: dict[str, Any] = {} 

85 more_tags = getattr(self, "_more_tags", None) 

86 if callable(more_tags): 

87 merged.update(more_tags()) 

88 extra_tags = getattr(self, "_sklearn_tags", None) 

89 if isinstance(extra_tags, dict): 

90 merged.update(extra_tags) 

91 

92 if _SKLEARN_TAGS_AVAILABLE: 

93 # Construct a Tags object with proper defaults 

94 estimator_type = merged.pop("estimator_type", None) 

95 non_deterministic = merged.pop("non_deterministic", False) 

96 requires_fit = merged.pop("requires_fit", True) 

97 requires_y = merged.pop("requires_y", True) 

98 

99 # Create target_tags 

100 target_tags = TargetTags(required=requires_y) 

101 

102 # Create Tags object 

103 tags = Tags( 

104 estimator_type=estimator_type, 

105 target_tags=target_tags, 

106 transformer_tags=None, 

107 classifier_tags=None, 

108 regressor_tags=None, 

109 non_deterministic=non_deterministic, 

110 requires_fit=requires_fit, 

111 ) 

112 

113 # Remaining keys are not recognised by the public Tags API; ignore gracefully. 

114 return tags 

115 

116 # Fallback lightweight representation when scikit-learn is not available. 

117 fallback = { 

118 "estimator_type": merged.pop("estimator_type", None), 

119 "non_deterministic": merged.pop("non_deterministic", False), 

120 "requires_y": merged.pop("requires_y", True), 

121 "requires_fit": merged.pop("requires_fit", True), 

122 } 

123 fallback.update(merged) 

124 return fallback 

125 

126 def __sklearn_is_fitted__(self) -> bool: 

127 """Expose estimator fitted state to scikit-learn utilities.""" 

128 return bool(getattr(self, "is_fitted_", False)) 

129 

130 

class FittedMixin:
    """Mixin providing utility to guard against using estimators pre-fit."""

    def _mark_fitted(self) -> None:
        # The sklearn-style trailing-underscore flag marks readiness for inference.
        self.is_fitted_ = True

    def _require_is_fitted(self, attributes: "Iterable[str] | None" = None) -> None:
        """Raise ``NotFittedError`` unless fitted and all ``attributes`` exist."""
        if not getattr(self, "is_fitted_", False):
            raise NotFittedError(f"{type(self).__name__} instance is not fitted yet.")
        if not attributes:
            return
        missing = [name for name in attributes if not hasattr(self, name)]
        if missing:
            raise NotFittedError(
                f"Estimator {type(self).__name__} is missing fitted attribute(s): {', '.join(missing)}"
            )

146 

147 

def check_is_fitted(estimator: "FittedMixin", attributes: "Iterable[str] | None" = None) -> None:
    """Validate that ``estimator`` has been fitted.

    Delegates to the estimator's own ``_require_is_fitted`` guard, which raises
    ``NotFittedError`` when ``is_fitted_`` is unset or when any of the requested
    fitted ``attributes`` are missing.
    """
    estimator._require_is_fitted(attributes)

151 

152 

def format_estimator_repr(
    name: str,
    config_pairs: Iterable[tuple[str, Any]],
    children: Iterable[tuple[str, str]],
    *,
    ascii_only: bool = False,
) -> str:
    r"""Compose a tree-style ``__repr__`` string for estimators.

    Parameters
    ----------
    name : str
        Display name for the estimator (typically ``type(self).__name__``).
    config_pairs : Iterable[tuple[str, Any]]
        ``(key, value)`` pairs describing configuration; ``None`` values are
        skipped, and the rest are rendered with ``repr`` so type information
        (quoted strings, etc.) is preserved.
    children : Iterable[tuple[str, str]]
        ``(label, description)`` items for child artefacts such as fitted
        submodels or optimizers. Multi-line descriptions are indented beneath
        their label.
    ascii_only : bool, default=False
        When ``True``, use ASCII connectors (``|--``/``\--``) instead of
        Unicode box-drawing characters.

    Returns
    -------
    str
        Multi-line representation combining the header and child sections.
    """
    fragments = [f"{key}={value!r}" for key, value in config_pairs if value is not None]
    header = f"{name}({', '.join(fragments)})" if fragments else f"{name}()"

    child_items = list(children)
    if not child_items:
        return header

    if ascii_only:
        branch_mid, branch_last, pad_mid, pad_last = "|--", "\\--", "|  ", "   "
    else:
        branch_mid, branch_last, pad_mid, pad_last = "├─", "└─", "│ ", "  "

    rendered = [header]
    last_index = len(child_items) - 1
    for position, (label, description) in enumerate(child_items):
        at_end = position == last_index
        connector = branch_last if at_end else branch_mid
        continuation = pad_last if at_end else pad_mid
        body_lines = str(description).splitlines() or [""]
        rendered.append(f"{connector} {label}: {body_lines[0]}")
        # Indent any additional description lines beneath the child label.
        if len(body_lines) > 1:
            prefix = f"{continuation} "
            rendered.extend(f"{prefix}{extra}" for extra in body_lines[1:])

    return "\n".join(rendered)

210 

211 

class RuleInspectorMixin(FittedMixin):
    """Mixin that exposes ANFIS rule descriptors for fitted estimators."""

    def get_rules(
        self,
        *,
        include_membership_functions: bool = False,
    ) -> list[tuple[int, ...]] | list[dict[str, Any]]:
        """Return the fuzzy rules learned by the estimator.

        Parameters
        ----------
        include_membership_functions : bool, default=False
            When ``False`` (default), return a list of tuples with the membership-function
            indices per input. When ``True``, return a list of dictionaries describing each
            rule with input names, membership-function indices, and their corresponding
            membership function instances (when available).

        Returns
        -------
        list
            Rule definitions either as tuples of integers (default) or as dictionaries with a
            rich description if ``include_membership_functions`` is ``True``.

        Raises
        ------
        NotFittedError
            If the estimator is not fitted, lacks ``rules_``, or (when a rich
            description is requested) lacks a fitted ``model_``.
        """
        check_is_fitted(self, attributes=["rules_"])

        raw_rules = getattr(self, "rules_", None) or []
        rule_tuples = [tuple(rule) for rule in raw_rules]

        if not include_membership_functions:
            return rule_tuples

        model = getattr(self, "model_", None)
        if model is None:
            raise NotFittedError(f"{type(self).__name__} instance is not fitted yet.")

        membership_map = getattr(model, "membership_functions", {})
        # NOTE: avoid `names or fallback` here -- `feature_names_in_` follows the
        # sklearn convention and may be a numpy array, whose truth value is
        # ambiguous and would raise ValueError. Test presence/length explicitly.
        names_attr = getattr(self, "feature_names_in_", None)
        if names_attr is not None and len(names_attr) > 0:
            feature_names = list(names_attr)
        else:
            feature_names = list(membership_map.keys())

        descriptors: list[dict[str, Any]] = []
        for rule_index, rule in enumerate(rule_tuples):
            antecedents: list[dict[str, Any]] = []
            for input_index, mf_index in enumerate(rule):
                if input_index < len(feature_names):
                    input_name = feature_names[input_index]
                else:  # Fallback to positional naming
                    input_name = f"x{input_index + 1}"

                mf_list = membership_map.get(input_name, [])
                membership_fn = None
                # Guard against out-of-range indices so we never raise IndexError.
                if 0 <= mf_index < len(mf_list):
                    membership_fn = mf_list[mf_index]

                antecedents.append(
                    {
                        "input": input_name,
                        "mf_index": int(mf_index),
                        "membership_function": membership_fn,
                    }
                )

            descriptors.append(
                {
                    "index": rule_index,
                    "rule": rule,
                    "antecedents": antecedents,
                }
            )

        return descriptors

282 

283 

class RegressorMixinLike:
    """Mixin implementing a default `score` method for regressors."""

    def predict(self, X: np.ndarray) -> np.ndarray:  # pragma: no cover - interface
        """Return predicted targets for ``X``."""
        raise NotImplementedError

    def score(self, X: np.ndarray, y: np.ndarray) -> float:
        """Return the coefficient of determination R^2 of the prediction.

        Follows the ``sklearn.metrics.r2_score`` convention for the degenerate
        constant-target case: a perfect prediction scores 1.0, any other
        prediction scores 0.0 (this also avoids a division by zero). Empty
        inputs score 0.0, consistent with ``ClassifierMixinLike``.

        Raises
        ------
        ValueError
            If predictions and targets have different shapes after flattening.
        """
        y_true = np.asarray(y, dtype=float).reshape(-1)
        y_pred = np.asarray(self.predict(X), dtype=float).reshape(-1)
        if y_true.shape != y_pred.shape:
            raise ValueError("Predicted values have a different shape than y.")
        if y_true.size == 0:
            return 0.0
        ss_res = float(np.sum((y_true - y_pred) ** 2))
        y_mean = float(np.mean(y_true))
        ss_tot = float(np.sum((y_true - y_mean) ** 2))
        if ss_tot == 0.0:
            # Constant target: R^2 is formally undefined; mirror sklearn's r2_score.
            return 1.0 if ss_res == 0.0 else 0.0
        return 1.0 - ss_res / ss_tot

303 

304 

class ClassifierMixinLike:
    """Mixin supplying a default accuracy-based ``score`` for classifiers."""

    def predict(self, X: np.ndarray) -> np.ndarray:  # pragma: no cover - interface
        """Return predicted class labels for ``X``."""
        raise NotImplementedError

    def score(self, X: np.ndarray, y: np.ndarray) -> float:
        """Return the mean accuracy on the given test data and labels."""
        labels = np.asarray(y)
        predictions = np.asarray(self.predict(X))
        if labels.shape != predictions.shape:
            raise ValueError("Predicted values have a different shape than y.")
        # An empty evaluation set scores zero rather than producing NaN.
        if labels.size == 0:
            return 0.0
        return float(np.mean(labels == predictions))

321 

322 

323@dataclass 

324class ValidationResult: 

325 X: np.ndarray 

326 y: np.ndarray | None 

327 feature_names: list[str] 

328 

329 

def ensure_2d_array(X: Any) -> tuple[np.ndarray, list[str]]:
    """Validate and convert input data to a 2D float64 numpy array.

    Returns the converted array together with feature names: column labels when
    ``X`` is DataFrame-like, otherwise positional names ``x1..xN``. Raises
    ``ValueError`` for non-2D input.
    """
    if hasattr(X, "to_numpy"):
        # DataFrame-like input: honour its column labels when present.
        values = X.to_numpy(dtype=float)
        columns = getattr(X, "columns", None)
        feature_names = None if columns is None else [str(col) for col in columns]
    else:
        values = np.asarray(X, dtype=float)
        feature_names = None

    if values.ndim != 2:
        raise ValueError("Input data must be 2-dimensional (n_samples, n_features).")

    if feature_names is None:
        feature_names = [f"x{i + 1}" for i in range(values.shape[1])]
    return values, feature_names

346 

347 

def ensure_vector(y: Any, *, allow_2d_column: bool = True) -> np.ndarray:
    """Validate and convert target data to a 1D numpy array.

    A ``(n, 1)`` column vector is flattened when ``allow_2d_column`` is true;
    any other non-1D shape raises ``ValueError``.
    """
    arr = np.asarray(y)
    if arr.ndim == 1:
        return arr
    if arr.ndim == 2:
        if arr.shape[1] == 1 and allow_2d_column:
            return arr.reshape(-1)
        raise ValueError("Target array must be 1-dimensional or a column vector.")
    raise ValueError("Target array must be 1-dimensional.")

359 

360 

# Backwards compatibility aliases (deprecated; prefer the public variants above)
_ensure_2d_array = ensure_2d_array  # deprecated: kept so old `_ensure_2d_array` imports keep working
_ensure_vector = ensure_vector  # deprecated: kept so old `_ensure_vector` imports keep working

364 

365 

def infer_feature_names(X: Any) -> list[str]:
    """Return feature names inferred from the input data structure.

    DataFrame-like inputs contribute their column labels (stringified); plain
    2D array-likes get positional names ``x1..xN``. Raises ``ValueError`` for
    non-2D array-like input.
    """
    if hasattr(X, "columns"):
        return [str(col) for col in X.columns]
    data = np.asarray(X)
    if data.ndim != 2:
        raise ValueError("Expected 2D array-like input to infer feature names.")
    return [f"x{idx + 1}" for idx in range(data.shape[1])]