Coverage for anfis_toolbox / config.py: 100%

118 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-05 18:47 -0300

1"""Configuration utilities for ANFIS models.""" 

2 

3import json 

4import logging 

5import pickle # nosec B403 

6from collections.abc import Mapping, Sequence 

7from copy import deepcopy 

8from pathlib import Path 

9from typing import Any, Protocol, TypedDict, cast 

10 

11from .builders import ANFISBuilder 

12from .model import TSKANFIS 

13 

14 

15class _InputConfig(TypedDict): 

16 range_min: float 

17 range_max: float 

18 n_mfs: int 

19 mf_type: str 

20 overlap: float 

21 

22 

23class _TrainingConfigRequired(TypedDict): 

24 method: str 

25 epochs: int 

26 learning_rate: float 

27 

28 

29class _TrainingConfigOptional(TypedDict, total=False): 

30 verbose: bool 

31 

32 

class _TrainingConfig(_TrainingConfigRequired, _TrainingConfigOptional):
    """Training settings: required ``method``/``epochs``/``learning_rate`` plus optional ``verbose``."""

35 

36 

# Shape of ``ANFISConfig.config``: input specs keyed by input name, training
# settings, and free-form extra model parameters.
_ConfigDict = TypedDict(
    "_ConfigDict",
    {
        "inputs": dict[str, _InputConfig],
        "training": _TrainingConfig,
        "model_params": dict[str, Any],
    },
)

41 

42 

# Shape of one entry in PREDEFINED_CONFIGS: a human-readable description plus
# the inputs and training settings used to seed an ANFISConfig.
_PresetConfig = TypedDict(
    "_PresetConfig",
    {
        "description": str,
        "inputs": dict[str, _InputConfig],
        "training": _TrainingConfig,
    },
)

47 

48 

49class _SupportsParameters(Protocol): 

50 @property 

51 def parameters(self) -> dict[str, Any]: # pragma: no cover - protocol definition 

52 """Return membership function parameters.""" 

53 ... 

54 

55 

56_MembershipConfig = dict[str, list[dict[str, Any]]] 

57 

58 

class ANFISConfig:
    """Configuration manager for ANFIS models.

    Holds a JSON-serializable dictionary (``self.config``) with three sections:

    * ``inputs``: per-input membership-function specifications, keyed by input name;
    * ``training``: training hyper-parameters;
    * ``model_params``: free-form extra parameters.

    Mutating methods return ``self`` so calls can be chained.
    """

    def __init__(self) -> None:
        """Initialize configuration manager with no inputs and default training settings."""
        self.config: _ConfigDict = {
            "inputs": {},
            "training": {"method": "hybrid", "epochs": 50, "learning_rate": 0.01, "verbose": False},
            "model_params": {},
        }

    def add_input_config(
        self,
        name: str,
        range_min: float,
        range_max: float,
        n_mfs: int = 3,
        mf_type: str = "gaussian",
        overlap: float = 0.5,
    ) -> "ANFISConfig":
        """Add (or replace) the configuration of one input variable.

        Parameters:
            name: Input variable name
            range_min: Minimum input range
            range_max: Maximum input range
            n_mfs: Number of membership functions
            mf_type: Type of membership functions
            overlap: Overlap factor

        Returns:
            Self for method chaining
        """
        self.config["inputs"][name] = {
            "range_min": range_min,
            "range_max": range_max,
            "n_mfs": n_mfs,
            "mf_type": mf_type,
            "overlap": overlap,
        }
        return self

    def set_training_config(
        self, method: str = "hybrid", epochs: int = 50, learning_rate: float = 0.01, verbose: bool = False
    ) -> "ANFISConfig":
        """Set training configuration.

        Parameters:
            method: Training method ('hybrid' or 'backprop')
            epochs: Number of training epochs
            learning_rate: Learning rate
            verbose: Whether to show training progress

        Returns:
            Self for method chaining
        """
        self.config["training"].update(
            {"method": method, "epochs": epochs, "learning_rate": learning_rate, "verbose": verbose}
        )
        return self

    def build_model(self) -> TSKANFIS:
        """Build ANFIS model from the configured inputs.

        Only the ``inputs`` section is used here; training settings are not
        applied by this method.

        Returns:
            Configured ANFIS model

        Raises:
            ValueError: If no inputs have been configured.
        """
        if not self.config["inputs"]:
            raise ValueError("No inputs configured. Use add_input_config() first.")

        builder = ANFISBuilder()
        for name, params in self.config["inputs"].items():
            # Coerce explicitly: values may come from hand-edited JSON files.
            builder.add_input(
                name=name,
                range_min=float(params["range_min"]),
                range_max=float(params["range_max"]),
                n_mfs=int(params["n_mfs"]),
                mf_type=str(params["mf_type"]),
                overlap=float(params["overlap"]),
            )

        return builder.build()

    def save(self, filepath: str | Path) -> None:
        """Save configuration to JSON file.

        Parent directories are created as needed.

        Parameters:
            filepath: Path to save configuration file

        Raises:
            TypeError: If the configuration (e.g. ``model_params``) holds values
                that cannot be encoded as JSON. The target file is left untouched.
        """
        filepath = Path(filepath)
        # Serialize before opening the file: ``json.dump`` writes incrementally,
        # so an unencodable value would otherwise leave a truncated file behind
        # (overwriting any previously saved, valid configuration).
        payload = json.dumps(self.config, indent=2)
        filepath.parent.mkdir(parents=True, exist_ok=True)

        with open(filepath, "w", encoding="utf-8") as f:
            f.write(payload)

    @classmethod
    def load(cls, filepath: str | Path) -> "ANFISConfig":
        """Load configuration from JSON file.

        Top-level sections (``inputs``, ``training``, ``model_params``) that are
        missing from the file are filled with the defaults of a fresh instance.

        Parameters:
            filepath: Path to configuration file

        Returns:
            ANFISConfig object

        Raises:
            ValueError: If the file does not contain a JSON object.
        """
        with open(filepath, encoding="utf-8") as f:
            config_data = json.load(f)

        if not isinstance(config_data, dict):
            raise ValueError(
                f"Invalid configuration file '{filepath}': expected a JSON object, "
                f"got {type(config_data).__name__}"
            )

        config = cls()
        # Backfill absent sections so later accesses such as config["training"]
        # in __repr__ do not fail with an obscure KeyError.
        for section, default in config.config.items():
            config_data.setdefault(section, default)
        config.config = cast(_ConfigDict, config_data)
        return config

    def to_dict(self) -> _ConfigDict:
        """Convert configuration to dictionary.

        Returns:
            A deep copy of the configuration dictionary (safe to mutate)
        """
        return deepcopy(self.config)

    def __repr__(self) -> str:
        """String representation of configuration."""
        inputs = self.config["inputs"]
        n_inputs = len(inputs)
        total_mfs = sum(int(inp["n_mfs"]) for inp in inputs.values())

        return f"ANFISConfig(inputs={n_inputs}, total_mfs={total_mfs}, method={self.config['training']['method']})"

189 

190 

class ANFISModelManager:
    """Model management utilities for saving/loading trained ANFIS models."""

    @staticmethod
    def save_model(model: TSKANFIS, filepath: str | Path, include_config: bool = True) -> None:
        """Save trained ANFIS model to file.

        The model is pickled to ``filepath``. When ``include_config`` is true, a
        JSON summary of the model is also written next to it
        (``filepath.with_suffix(".config.json")``). Writing that summary is
        best-effort: any failure is logged as a warning and not raised.

        Parameters:
            model: Trained ANFIS model
            filepath: Path to save model file
            include_config: Whether to save model configuration
        """
        filepath = Path(filepath)
        # Pickle to bytes first so an unpicklable model raises before the target
        # file is opened, instead of leaving a truncated file behind.
        model_bytes = pickle.dumps(model)  # nosec B301
        filepath.parent.mkdir(parents=True, exist_ok=True)

        with open(filepath, "wb") as f:
            f.write(model_bytes)

        # Save configuration if requested
        if include_config:
            config_path = filepath.with_suffix(".config.json")
            try:
                config = ANFISModelManager._extract_config(model)
                # Encode fully before writing: json.dump streams and would leave a
                # partial file if some parameter turned out not to be serializable.
                config_text = json.dumps(config, indent=2)
                with open(config_path, "w", encoding="utf-8") as f:
                    f.write(config_text)
            except Exception as e:
                logging.warning("Could not save model configuration: %s", e)

    @staticmethod
    def load_model(filepath: str | Path) -> TSKANFIS:
        """Load trained ANFIS model from file.

        Security: this unpickles the file, and unpickling can execute arbitrary
        code. Only load model files from trusted sources.

        Parameters:
            filepath: Path to model file

        Returns:
            Loaded ANFIS model
        """
        with open(filepath, "rb") as f:
            model: TSKANFIS = pickle.load(f)  # nosec B301

        return model

    @staticmethod
    def _to_json_native(value: Any) -> Any:
        """Convert numpy-like scalars/arrays to native Python objects for JSON.

        ``tolist()`` is preferred because it handles scalars and multi-element
        arrays alike. ``item()`` only works for single-element values and is kept
        as a fallback for objects that expose only ``item``. Anything else is
        returned unchanged.
        """
        to_list = getattr(value, "tolist", None)
        if callable(to_list):
            return to_list()
        to_item = getattr(value, "item", None)
        if callable(to_item):
            return to_item()
        return value

    @staticmethod
    def _extract_config(model: TSKANFIS) -> dict[str, Any]:
        """Extract configuration from trained model.

        Parameters:
            model: ANFIS model

        Returns:
            Model configuration dictionary with ``model_info`` (input count, rule
            count, input names) and ``membership_functions`` (per input, a list
            of ``{"type": <class name>, "parameters": {...}}`` entries)
        """
        # Use standardized interface: both model and membership_layer have membership_functions property
        membership_functions: Mapping[str, Sequence[_SupportsParameters]] = model.membership_functions
        input_names = model.input_names

        membership_config: _MembershipConfig = {}
        config: dict[str, Any] = {
            "model_info": {
                "n_inputs": int(model.n_inputs),
                "n_rules": int(model.n_rules),
                "input_names": input_names,
            },
            "membership_functions": membership_config,
        }

        # Extract MF information from each input channel
        for input_name, mfs in membership_functions.items():
            membership_config[input_name] = [
                {
                    "type": mf.__class__.__name__,
                    # Build a new dict so the model's own parameter dict is not mutated.
                    "parameters": {
                        key: ANFISModelManager._to_json_native(value) for key, value in mf.parameters.items()
                    },
                }
                for mf in mfs
            ]

        return config

274 

275 

# Predefined configurations for common use cases


def _preset_input(range_min: float, range_max: float, n_mfs: int, mf_type: str, overlap: float) -> _InputConfig:
    """Return a fresh input-specification dict for use inside a preset."""
    return {
        "range_min": range_min,
        "range_max": range_max,
        "n_mfs": n_mfs,
        "mf_type": mf_type,
        "overlap": overlap,
    }


PREDEFINED_CONFIGS: dict[str, _PresetConfig] = {
    "1d_function": {
        "description": "Single input function approximation",
        "inputs": {
            "x": _preset_input(-5, 5, 5, "gaussian", 0.5),
        },
        "training": {"method": "hybrid", "epochs": 100, "learning_rate": 0.01},
    },
    "2d_regression": {
        "description": "Two-input regression problem",
        "inputs": {
            "x1": _preset_input(-2, 2, 3, "gaussian", 0.5),
            "x2": _preset_input(-2, 2, 3, "gaussian", 0.5),
        },
        "training": {"method": "hybrid", "epochs": 50, "learning_rate": 0.01},
    },
    "control_system": {
        "description": "Control system with error and error rate",
        "inputs": {
            "error": _preset_input(-1, 1, 5, "triangular", 0.3),
            "error_rate": _preset_input(-1, 1, 5, "triangular", 0.3),
        },
        "training": {"method": "hybrid", "epochs": 75, "learning_rate": 0.015},
    },
    "time_series": {
        "description": "Time series prediction with lag inputs",
        "inputs": {
            "lag1": _preset_input(-3, 3, 4, "gaussian", 0.4),
            "lag2": _preset_input(-3, 3, 4, "gaussian", 0.4),
            "lag3": _preset_input(-3, 3, 3, "gaussian", 0.4),
        },
        "training": {"method": "hybrid", "epochs": 60, "learning_rate": 0.008},
    },
}

309 

310 

def create_config_from_preset(preset_name: str) -> ANFISConfig:
    """Build an ``ANFISConfig`` populated from one of ``PREDEFINED_CONFIGS``.

    Parameters:
        preset_name: Name of predefined configuration

    Returns:
        ANFISConfig object with the preset's inputs and training settings

    Raises:
        ValueError: If preset name not found
    """
    preset = PREDEFINED_CONFIGS.get(preset_name)
    if preset is None:
        raise ValueError(f"Preset '{preset_name}' not found. Available presets: {list(PREDEFINED_CONFIGS)}")

    result = ANFISConfig()
    for input_name, input_spec in preset["inputs"].items():
        result.add_input_config(input_name, **input_spec)

    # set_training_config returns the same instance, so its result is the config itself.
    return result.set_training_config(**preset["training"])

338 

339 

def list_presets() -> dict[str, str]:
    """Map each predefined configuration name to its human-readable description.

    Returns:
        Dictionary mapping preset names to descriptions, in definition order
    """
    return {
        preset_name: preset["description"]
        for preset_name, preset in PREDEFINED_CONFIGS.items()
    }