Coverage for anfis_toolbox / config.py: 100%
118 statements
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-05 18:47 -0300
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-05 18:47 -0300
1"""Configuration utilities for ANFIS models."""
3import json
4import logging
5import pickle # nosec B403
6from collections.abc import Mapping, Sequence
7from copy import deepcopy
8from pathlib import Path
9from typing import Any, Protocol, TypedDict, cast
11from .builders import ANFISBuilder
12from .model import TSKANFIS
class _InputConfig(TypedDict):
    """Settings for one input variable, as stored under ``config["inputs"][name]``.

    Keys mirror the keyword arguments of ``ANFISConfig.add_input_config`` and are
    forwarded to ``ANFISBuilder.add_input`` by ``ANFISConfig.build_model``.
    """

    range_min: float  # lower bound of the input range
    range_max: float  # upper bound of the input range
    n_mfs: int  # number of membership functions for this input
    mf_type: str  # membership-function type name, e.g. "gaussian" or "triangular"
    overlap: float  # overlap factor between neighbouring membership functions
class _TrainingConfigRequired(TypedDict):
    """Training settings that every training configuration must define."""

    method: str  # training method name; set_training_config documents 'hybrid' or 'backprop'
    epochs: int  # number of training epochs
    learning_rate: float  # learning rate used during training
class _TrainingConfigOptional(TypedDict, total=False):
    """Training settings that may be omitted (``total=False``).

    The entries in ``PREDEFINED_CONFIGS`` leave these keys out.
    """

    verbose: bool  # whether to show training progress
class _TrainingConfig(_TrainingConfigRequired, _TrainingConfigOptional):
    """Full training configuration: the required keys plus the optional ``verbose`` flag.

    Combining a total and a non-total base lets ``verbose`` be omitted while
    ``method``, ``epochs`` and ``learning_rate`` stay required.
    """

    pass
class _ConfigDict(TypedDict):
    """Shape of ``ANFISConfig.config`` (and of the JSON written by ``ANFISConfig.save``)."""

    inputs: dict[str, _InputConfig]  # input name -> per-input settings
    training: _TrainingConfig  # training settings
    model_params: dict[str, Any]  # free-form extra parameters; ANFISConfig initialises it empty
class _PresetConfig(TypedDict):
    """Shape of one entry in ``PREDEFINED_CONFIGS``."""

    description: str  # human-readable summary, returned by list_presets()
    inputs: dict[str, _InputConfig]  # each entry is passed to ANFISConfig.add_input_config
    training: _TrainingConfig  # passed to ANFISConfig.set_training_config
class _SupportsParameters(Protocol):
    """Structural type for membership functions: anything exposing a ``parameters`` dict.

    In this module it annotates the membership functions iterated by
    ``ANFISModelManager._extract_config``.
    """

    @property
    def parameters(self) -> dict[str, Any]:  # pragma: no cover - protocol definition
        """Return membership function parameters."""
        ...
# Type alias: input name -> list of {"type": <MF class name>, "parameters": {...}} dicts,
# the structure built by ANFISModelManager._extract_config.
_MembershipConfig = dict[str, list[dict[str, Any]]]
class ANFISConfig:
    """Configuration manager for ANFIS models.

    The whole configuration lives in ``self.config``, a JSON-serialisable dict
    with three sections:

    * ``"inputs"`` - per-input settings consumed by :meth:`build_model`;
    * ``"training"`` - method, epochs, learning rate and verbosity;
    * ``"model_params"`` - free-form extra parameters (initialised empty).

    Mutating methods return ``self`` so calls can be chained.
    """

    def __init__(self) -> None:
        """Initialize configuration manager with no inputs and default training settings."""
        self.config: _ConfigDict = {
            "inputs": {},
            "training": {"method": "hybrid", "epochs": 50, "learning_rate": 0.01, "verbose": False},
            "model_params": {},
        }

    def add_input_config(
        self,
        name: str,
        range_min: float,
        range_max: float,
        n_mfs: int = 3,
        mf_type: str = "gaussian",
        overlap: float = 0.5,
    ) -> "ANFISConfig":
        """Add input configuration.

        Adding an input whose ``name`` is already configured replaces its settings.

        Parameters:
            name: Input variable name
            range_min: Minimum input range
            range_max: Maximum input range
            n_mfs: Number of membership functions
            mf_type: Type of membership functions
            overlap: Overlap factor

        Returns:
            Self for method chaining
        """
        self.config["inputs"][name] = {
            "range_min": range_min,
            "range_max": range_max,
            "n_mfs": n_mfs,
            "mf_type": mf_type,
            "overlap": overlap,
        }
        return self

    def set_training_config(
        self, method: str = "hybrid", epochs: int = 50, learning_rate: float = 0.01, verbose: bool = False
    ) -> "ANFISConfig":
        """Set training configuration.

        Parameters:
            method: Training method ('hybrid' or 'backprop')
            epochs: Number of training epochs
            learning_rate: Learning rate
            verbose: Whether to show training progress

        Returns:
            Self for method chaining
        """
        self.config["training"].update(
            {"method": method, "epochs": epochs, "learning_rate": learning_rate, "verbose": verbose}
        )
        return self

    def build_model(self) -> TSKANFIS:
        """Build ANFIS model from configuration.

        Every configured input is passed to ``ANFISBuilder.add_input`` in
        insertion order, with values coerced to their declared types.

        Returns:
            Configured ANFIS model

        Raises:
            ValueError: If no inputs have been configured.
        """
        if not self.config["inputs"]:
            raise ValueError("No inputs configured. Use add_input_config() first.")

        builder = ANFISBuilder()

        for name, params in self.config["inputs"].items():
            # Coerce explicitly: values loaded from JSON may be ints where floats are expected.
            builder.add_input(
                name=name,
                range_min=float(params["range_min"]),
                range_max=float(params["range_max"]),
                n_mfs=int(params["n_mfs"]),
                mf_type=str(params["mf_type"]),
                overlap=float(params["overlap"]),
            )

        return builder.build()

    def save(self, filepath: str | Path) -> None:
        """Save configuration to JSON file.

        The configuration is serialised in memory before the file is opened, so
        a value that is not JSON-serialisable (e.g. an object in
        ``model_params``) raises without truncating an existing file.
        Parent directories are created as needed.

        Parameters:
            filepath: Path to save configuration file

        Raises:
            TypeError: If the configuration contains a non-JSON-serialisable value.
        """
        filepath = Path(filepath)
        payload = json.dumps(self.config, indent=2)
        filepath.parent.mkdir(parents=True, exist_ok=True)
        filepath.write_text(payload, encoding="utf-8")

    @classmethod
    def load(cls, filepath: str | Path) -> "ANFISConfig":
        """Load configuration from JSON file.

        Sections missing from the file, and training keys missing from its
        ``"training"`` section, are filled in from the defaults of a fresh
        :class:`ANFISConfig`. Extra keys present in the file are kept as-is.

        Parameters:
            filepath: Path to configuration file (UTF-8 encoded JSON)

        Returns:
            ANFISConfig object

        Raises:
            ValueError: If the file is not valid JSON, does not hold a JSON
                object, or has an ``inputs``/``training``/``model_params``
                section that is not a JSON object.
        """
        with open(filepath, encoding="utf-8") as f:
            config_data = json.load(f)

        if not isinstance(config_data, dict):
            raise ValueError(f"Configuration file must contain a JSON object, got {type(config_data).__name__}")

        config = cls()
        defaults = config.config
        for section in defaults:
            if section in config_data and not isinstance(config_data[section], dict):
                raise ValueError(f"Configuration section '{section}' must be a JSON object")

        merged: dict[str, Any] = {**defaults, **config_data}
        # Partial training sections (e.g. without "verbose") fall back to the defaults key by key.
        merged["training"] = {**defaults["training"], **config_data.get("training", {})}
        config.config = cast(_ConfigDict, merged)
        return config

    def to_dict(self) -> _ConfigDict:
        """Convert configuration to dictionary.

        Returns:
            Deep copy of the configuration dictionary; mutating it does not
            affect this object.
        """
        return deepcopy(self.config)

    def __repr__(self) -> str:
        """String representation of configuration."""
        inputs = self.config["inputs"]
        n_inputs = len(inputs)
        total_mfs = sum(int(inp["n_mfs"]) for inp in inputs.values())

        return f"ANFISConfig(inputs={n_inputs}, total_mfs={total_mfs}, method={self.config['training']['method']})"
class ANFISModelManager:
    """Model management utilities for saving/loading trained ANFIS models."""

    @staticmethod
    def save_model(model: TSKANFIS, filepath: str | Path, include_config: bool = True) -> None:
        """Save trained ANFIS model to file.

        The model is pickled to ``filepath``. When ``include_config`` is true, a
        JSON summary from :meth:`_extract_config` is also written to
        ``filepath.with_suffix(".config.json")`` (e.g. ``model.pkl`` ->
        ``model.config.json``). Writing that summary is best-effort: any
        failure is logged as a warning and the pickled model is kept.

        Parameters:
            model: Trained ANFIS model
            filepath: Path to save model file
            include_config: Whether to save model configuration
        """
        filepath = Path(filepath)
        filepath.parent.mkdir(parents=True, exist_ok=True)

        with open(filepath, "wb") as f:
            pickle.dump(model, f)  # nosec B301

        if include_config:
            config_path = filepath.with_suffix(".config.json")
            try:
                # Serialise first so a failure cannot leave a truncated JSON file behind.
                payload = json.dumps(ANFISModelManager._extract_config(model), indent=2)
                config_path.write_text(payload, encoding="utf-8")
            except Exception as e:  # deliberate best-effort: the model itself is already saved
                # Named logger instead of logging.warning(), which would call
                # basicConfig() on the root logger as a side effect.
                logging.getLogger(__name__).warning("Could not save model configuration: %s", e)

    @staticmethod
    def load_model(filepath: str | Path) -> TSKANFIS:
        """Load trained ANFIS model from file.

        Warning:
            This uses :func:`pickle.load`, which can execute arbitrary code.
            Only load files from trusted sources.

        Parameters:
            filepath: Path to model file

        Returns:
            Loaded ANFIS model
        """
        with open(filepath, "rb") as f:
            model: TSKANFIS = pickle.load(f)  # nosec B301 - trusted files only, see docstring

        return model

    @staticmethod
    def _to_builtin(value: Any) -> Any:
        """Recursively convert numpy-like values into JSON-serialisable Python builtins.

        Scalars exposing ``item()`` become Python scalars. Array-likes whose
        ``item()`` fails (more than one element) fall back to ``tolist()``.
        Mappings, lists and tuples are converted element-wise; anything else is
        returned unchanged.
        """
        if isinstance(value, Mapping):
            return {key: ANFISModelManager._to_builtin(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            return [ANFISModelManager._to_builtin(item) for item in value]
        if hasattr(value, "item"):
            try:
                return value.item()
            except (ValueError, RuntimeError):
                # numpy raises ValueError / torch raises RuntimeError for multi-element arrays.
                pass
        if hasattr(value, "tolist"):
            return value.tolist()
        return value

    @staticmethod
    def _extract_config(model: TSKANFIS) -> dict[str, Any]:
        """Extract a JSON-serialisable configuration summary from a trained model.

        Parameters:
            model: ANFIS model

        Returns:
            Dictionary with ``"model_info"`` (``n_inputs``, ``n_rules``,
            ``input_names``) and ``"membership_functions"``, which maps each
            input name to a list of ``{"type": <MF class name>, "parameters": {...}}``.
        """
        # Use standardized interface: both model and membership_layer have membership_functions property
        membership_functions: Mapping[str, Sequence[_SupportsParameters]] = model.membership_functions
        to_builtin = ANFISModelManager._to_builtin

        membership_config: _MembershipConfig = {}
        for input_name, mfs in membership_functions.items():
            # to_builtin builds a new dict, so the MF's own parameters are never mutated.
            membership_config[input_name] = [
                {"type": mf.__class__.__name__, "parameters": to_builtin(mf.parameters)} for mf in mfs
            ]

        return {
            "model_info": {
                "n_inputs": int(model.n_inputs),
                "n_rules": int(model.n_rules),
                "input_names": to_builtin(model.input_names),
            },
            "membership_functions": membership_config,
        }
# Predefined configurations for common use cases.
# Each entry maps a preset name to a _PresetConfig: create_config_from_preset() turns
# one into an ANFISConfig, and list_presets() exposes the descriptions.
# The "training" sections omit "verbose", so presets get set_training_config's default (False).
PREDEFINED_CONFIGS: dict[str, _PresetConfig] = {
    "1d_function": {
        "description": "Single input function approximation",
        "inputs": {"x": {"range_min": -5, "range_max": 5, "n_mfs": 5, "mf_type": "gaussian", "overlap": 0.5}},
        "training": {"method": "hybrid", "epochs": 100, "learning_rate": 0.01},
    },
    "2d_regression": {
        "description": "Two-input regression problem",
        "inputs": {
            "x1": {"range_min": -2, "range_max": 2, "n_mfs": 3, "mf_type": "gaussian", "overlap": 0.5},
            "x2": {"range_min": -2, "range_max": 2, "n_mfs": 3, "mf_type": "gaussian", "overlap": 0.5},
        },
        "training": {"method": "hybrid", "epochs": 50, "learning_rate": 0.01},
    },
    "control_system": {
        "description": "Control system with error and error rate",
        "inputs": {
            "error": {"range_min": -1, "range_max": 1, "n_mfs": 5, "mf_type": "triangular", "overlap": 0.3},
            "error_rate": {"range_min": -1, "range_max": 1, "n_mfs": 5, "mf_type": "triangular", "overlap": 0.3},
        },
        "training": {"method": "hybrid", "epochs": 75, "learning_rate": 0.015},
    },
    "time_series": {
        "description": "Time series prediction with lag inputs",
        "inputs": {
            # lag1..lag3 presumably hold the series value 1..3 steps back — naming only, not enforced here
            "lag1": {"range_min": -3, "range_max": 3, "n_mfs": 4, "mf_type": "gaussian", "overlap": 0.4},
            "lag2": {"range_min": -3, "range_max": 3, "n_mfs": 4, "mf_type": "gaussian", "overlap": 0.4},
            "lag3": {"range_min": -3, "range_max": 3, "n_mfs": 3, "mf_type": "gaussian", "overlap": 0.4},
        },
        "training": {"method": "hybrid", "epochs": 60, "learning_rate": 0.008},
    },
}
def create_config_from_preset(preset_name: str) -> ANFISConfig:
    """Build a fresh :class:`ANFISConfig` from an entry of ``PREDEFINED_CONFIGS``.

    Parameters:
        preset_name: Key of the preset to use (see ``list_presets()``)

    Returns:
        New ANFISConfig holding the preset's inputs and training settings

    Raises:
        ValueError: If preset name not found
    """
    preset = PREDEFINED_CONFIGS.get(preset_name)
    if preset is None:
        known_names = list(PREDEFINED_CONFIGS)
        raise ValueError(f"Preset '{preset_name}' not found. Available presets: {known_names}")

    result = ANFISConfig()
    for input_name, input_settings in preset["inputs"].items():
        result.add_input_config(input_name, **input_settings)

    # set_training_config returns the same instance, so this hands back `result`.
    return result.set_training_config(**preset["training"])
def list_presets() -> dict[str, str]:
    """Describe every predefined configuration.

    Returns:
        New dictionary mapping each preset name in ``PREDEFINED_CONFIGS`` to its
        ``"description"`` text, in the table's insertion order
    """
    return {preset_name: PREDEFINED_CONFIGS[preset_name]["description"] for preset_name in PREDEFINED_CONFIGS}