Coverage for anfis_toolbox / optim / _utils.py: 100%
49 statements
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-05 18:47 -0300
« prev ^ index » next coverage.py v7.13.3, created at 2026-02-05 18:47 -0300
1from __future__ import annotations
3from collections.abc import Iterator
4from typing import Any
6import numpy as np
def zeros_like_structure(params: Any) -> dict[str, Any]:
    """Build an all-zero copy of a parameter structure.

    The input is expected to look like ``model.get_parameters()`` output.
    The result has the same layout:

    - ``'consequent'``: ``np.zeros_like(params['consequent'])``
    - ``'membership'``: ``{name: [{param_name: 0.0, ...}, ...]}``, with one
      zeroed dict per membership-function dict in ``params``, keeping the
      same keys in the same order.

    The input is not modified.
    """
    # Zero the consequent first so errors surface in the same order as before.
    consequent_zeros = np.zeros_like(params["consequent"])
    membership_zeros = {
        name: [dict.fromkeys(mf_params.keys(), 0.0) for mf_params in mf_list]
        for name, mf_list in params["membership"].items()
    }
    return {"consequent": consequent_zeros, "membership": membership_zeros}
def iterate_membership_params(
    params_dict: Any,
    grads_dict: Any | None = None,
) -> Iterator[tuple[tuple[str, int, str], float, float | None]]:
    """Yield each membership parameter, optionally paired with its gradient.

    Walks ``params_dict["membership"]`` (a mapping of name -> list of
    parameter dicts) and yields ``((name, index, key), value, grad)`` where
    ``value`` is the parameter cast to ``float``. ``grad`` is the matching
    entry of ``grads_dict["membership"]`` cast to ``float``, or ``None``
    when ``grads_dict`` is ``None``.

    See optim/parameter_utils.py for structure details.
    """
    for mf_name, mf_list in params_dict["membership"].items():
        for mf_index, mf_params in enumerate(mf_list):
            for param_name in mf_params:
                value = float(mf_params[param_name])
                if grads_dict is None:
                    grad = None
                else:
                    # Looked up lazily: grads_dict is only touched when there is a parameter.
                    grad = float(grads_dict["membership"][mf_name][mf_index][param_name])
                yield (mf_name, mf_index, param_name), value, grad
def iterate_membership_params_with_state(
    params_dict: Any,
    state_dict: Any,
    grads_dict: Any | None = None,
) -> Iterator[tuple[tuple[str, int, str], float, float, float | None]]:
    """Yield each membership parameter with its optimizer state and gradient.

    Intended for momentum-style optimizers whose ``state_dict`` mirrors the
    ``"membership"`` layout of ``params_dict``. Yields
    ``((name, index, key), value, state, grad)``. ``value`` and ``state``
    are cast to ``float``. ``grad`` is the matching ``grads_dict`` entry
    cast to ``float``, or ``None`` when ``grads_dict`` is ``None``.
    """
    for mf_name, mf_list in params_dict["membership"].items():
        for mf_index, mf_params in enumerate(mf_list):
            for param_name in mf_params:
                path = (mf_name, mf_index, param_name)
                # Conversion order (param, state, grad) matches the original contract.
                value = float(mf_params[param_name])
                state = float(state_dict["membership"][mf_name][mf_index][param_name])
                grad = (
                    None
                    if grads_dict is None
                    else float(grads_dict["membership"][mf_name][mf_index][param_name])
                )
                yield path, value, state, grad
def update_membership_param(
    params_dict: Any,
    path: tuple[str, int, str],
    value: float,
) -> None:
    """Overwrite one membership parameter in place.

    ``path`` is ``(name, index, key)``, as yielded by the iteration helpers.
    ``value`` is stored as ``float`` at
    ``params_dict["membership"][name][index][key]``.
    """
    mf_name, mf_index, param_name = path
    # Convert before the lookup, mirroring the RHS-first evaluation of a plain assignment.
    new_value = float(value)
    mf_params = params_dict["membership"][mf_name][mf_index]
    mf_params[param_name] = new_value
def get_membership_param(
    params_dict: Any,
    path: tuple[str, int, str],
) -> float:
    """Return one membership parameter as ``float``.

    ``path`` is ``(name, index, key)`` addressing
    ``params_dict["membership"][name][index][key]``.
    """
    mf_name, mf_index, param_name = path
    mf_params = params_dict["membership"][mf_name][mf_index]
    return float(mf_params[param_name])
def flatten_membership_params(params_dict: Any) -> tuple[np.ndarray, list[tuple[str, int, str]]]:
    """Flatten all membership parameters into a 1-D float array.

    Returns ``(values, paths)`` where ``values[k]`` is the parameter stored
    at ``paths[k] == (name, index, key)``. Traversal follows dict/list
    order of ``params_dict["membership"]``. An empty structure yields an
    empty float array and an empty list.
    """
    entries = [
        ((mf_name, mf_index, param_name), float(raw_value))
        for mf_name, mf_list in params_dict["membership"].items()
        for mf_index, mf_params in enumerate(mf_list)
        for param_name, raw_value in mf_params.items()
    ]
    paths: list[tuple[str, int, str]] = [path for path, _ in entries]
    flat = np.array([value for _, value in entries], dtype=float)
    return flat, paths
def unflatten_membership_params(
    flat_array: np.ndarray,
    paths: list[tuple[str, int, str]],
    params_dict: Any,
) -> None:
    """Write a flat vector back into ``params_dict["membership"]`` in place.

    Inverse of ``flatten_membership_params``: ``flat_array[k]`` is stored (as
    ``float``) at ``params_dict["membership"][name][index][key]`` for
    ``paths[k] == (name, index, key)``. Entries of ``flat_array`` beyond
    ``len(paths)`` are ignored.

    Raises:
        IndexError: if ``flat_array`` has fewer entries than ``paths``. The
            check runs before any write, so ``params_dict`` is never left
            partially updated.
    """
    n_paths = len(paths)
    # A 0-d input has no indexable entries; treat it as length 0.
    n_values = np.shape(flat_array)[0] if np.ndim(flat_array) else 0
    if n_values < n_paths:
        raise IndexError(
            f"flat_array has {n_values} entries but {n_paths} membership paths were given"
        )
    for (name, i, key), value in zip(paths, flat_array):
        params_dict["membership"][name][i][key] = float(value)
# Public API of this module; names not listed here are private helpers.
__all__ = [
    "zeros_like_structure",
    "iterate_membership_params",
    "iterate_membership_params_with_state",
    "update_membership_param",
    "get_membership_param",
    "flatten_membership_params",
    "unflatten_membership_params",
]