Coverage for anfis_toolbox / optim / _utils.py: 100%

49 statements  

« prev     ^ index     » next       coverage.py v7.13.3, created at 2026-02-05 18:47 -0300

1from __future__ import annotations 

2 

3from collections.abc import Iterator 

4from typing import Any 

5 

6import numpy as np 

7 

8 

def zeros_like_structure(params: Any) -> dict[str, Any]:
    """Return a zero-filled structure with the same layout as ``params``.

    ``params`` follows the ``model.get_parameters()`` format: a
    ``"consequent"`` array-like plus a ``"membership"`` mapping of
    ``name -> [{param_name: value, ...}, ...]``.

    Returns:
        A new dict with
        - ``"consequent"``: ``np.zeros_like(params["consequent"])``;
        - ``"membership"``: ``{name: [{param_name: 0.0, ...}, ...]}`` with the
          same names, list lengths and parameter keys as the input.
        The input structure itself is not modified.
    """
    consequent_zeros = np.zeros_like(params["consequent"])
    membership_zeros = {
        name: [dict.fromkeys(mf_params.keys(), 0.0) for mf_params in mf_list]
        for name, mf_list in params["membership"].items()
    }
    return {"consequent": consequent_zeros, "membership": membership_zeros}

22 

23 

def iterate_membership_params(
    params_dict: Any,
    grads_dict: Any | None = None,
) -> Iterator[tuple[tuple[str, int, str], float, float | None]]:
    """Iterate over membership parameters with their gradients.

    ``params_dict["membership"]`` maps each name to a list of
    membership-function parameter dicts, i.e.
    ``{name: [{param_name: value, ...}, ...]}``.

    Args:
        params_dict: Parameter structure containing a ``"membership"`` entry.
        grads_dict: Optional structure whose ``"membership"`` entry mirrors
            ``params_dict["membership"]`` and holds gradients. When ``None``,
            every yielded gradient is ``None``.

    Yields:
        ``(path, value, grad)``, where:
        - ``path`` is ``(name, mf_index, param_name)``;
        - ``value`` is the parameter converted to ``float``;
        - ``grad`` is the matching gradient as ``float``, or ``None``.

    Raises:
        Lookup errors (e.g. ``KeyError`` / ``IndexError``) propagate if
        ``grads_dict`` does not mirror the layout of ``params_dict``.
    """
    for name, mf_list in params_dict["membership"].items():
        for i, mf_dict in enumerate(mf_list):
            # Use the already-bound dict instead of re-walking params_dict
            # four levels deep for every parameter.
            for key, raw_value in mf_dict.items():
                param_val = float(raw_value)
                grad_val = None
                if grads_dict is not None:
                    # The gradient lookup stays inside the loop, so grads_dict
                    # is only touched when a parameter actually exists.
                    grad_val = float(grads_dict["membership"][name][i][key])
                yield (name, i, key), param_val, grad_val

40 

41 

def iterate_membership_params_with_state(
    params_dict: Any,
    state_dict: Any,
    grads_dict: Any | None = None,
) -> Iterator[tuple[tuple[str, int, str], float, float, float | None]]:
    """Iterate over membership parameters with state (for momentum-based optimizers).

    ``params_dict["membership"]`` is laid out as
    ``{name: [{param_name: value, ...}, ...]}``. ``state_dict["membership"]``
    (and ``grads_dict["membership"]`` when given) must mirror that layout,
    holding one scalar per parameter.

    Args:
        params_dict: Parameter structure containing a ``"membership"`` entry.
        state_dict: Per-parameter optimizer state with the same layout.
        grads_dict: Optional per-parameter gradients with the same layout.
            When ``None``, every yielded gradient is ``None``.

    Yields:
        ``(path, value, state, grad)``, where:
        - ``path`` is ``(name, mf_index, param_name)``;
        - ``value`` and ``state`` are converted to ``float``;
        - ``grad`` is a ``float``, or ``None``.

    Raises:
        Lookup errors (e.g. ``KeyError`` / ``IndexError``) propagate if
        ``state_dict`` or ``grads_dict`` do not mirror ``params_dict``.
    """
    for name, mf_list in params_dict["membership"].items():
        for i, mf_dict in enumerate(mf_list):
            # Read the parameter from the bound dict; state and gradients are
            # looked up lazily per parameter, in the order param, state, grad.
            for key, raw_value in mf_dict.items():
                param_val = float(raw_value)
                state_val = float(state_dict["membership"][name][i][key])
                grad_val = None
                if grads_dict is not None:
                    grad_val = float(grads_dict["membership"][name][i][key])
                yield (name, i, key), param_val, state_val, grad_val

57 

58 

def update_membership_param(
    params_dict: Any,
    path: tuple[str, int, str],
    value: float,
) -> None:
    """Overwrite one membership parameter in place.

    Args:
        params_dict: Structure with a ``"membership"`` entry laid out as
            ``{name: [{param_name: value, ...}, ...]}``; mutated in place.
        path: ``(name, mf_index, param_name)`` locating the parameter.
        value: New value; stored as a plain ``float``.
    """
    mf_name, mf_index, param_name = path
    new_value = float(value)
    mf_params = params_dict["membership"][mf_name][mf_index]
    mf_params[param_name] = new_value

66 

67 

def get_membership_param(
    params_dict: Any,
    path: tuple[str, int, str],
) -> float:
    """Read a single membership parameter addressed by ``path``.

    Args:
        params_dict: Structure with a ``"membership"`` entry laid out as
            ``{name: [{param_name: value, ...}, ...]}``.
        path: ``(name, mf_index, param_name)`` locating the parameter.

    Returns:
        The parameter value converted to ``float``.
    """
    mf_name, mf_index, param_name = path
    mf_params = params_dict["membership"][mf_name][mf_index]
    return float(mf_params[param_name])

74 

75 

def flatten_membership_params(params_dict: Any) -> tuple[np.ndarray, list[tuple[str, int, str]]]:
    """Flatten all membership parameters into a 1-D float array plus their paths.

    Traversal order is:
    1. names in ``params_dict["membership"]`` iteration order;
    2. then membership-function index;
    3. then parameter-key order.

    Args:
        params_dict: Structure with a ``"membership"`` entry laid out as
            ``{name: [{param_name: value, ...}, ...]}``.

    Returns:
        ``(values, paths)``:
        - ``values`` is a 1-D ``float`` ndarray;
        - ``paths`` is a list of ``(name, mf_index, param_name)`` tuples;
        - ``paths[k]`` locates ``values[k]``.
        Both are empty (``values.shape == (0,)``) when there are no parameters.
    """
    paths: list[tuple[str, int, str]] = []
    values: list[float] = []
    for name, mf_list in params_dict["membership"].items():
        for i, mf_dict in enumerate(mf_list):
            # Read from the bound dict instead of re-indexing params_dict.
            for key, raw_value in mf_dict.items():
                paths.append((name, i, key))
                values.append(float(raw_value))
    return np.asarray(values, dtype=float), paths

85 

86 

def unflatten_membership_params(
    flat_array: np.ndarray,
    paths: list[tuple[str, int, str]],
    params_dict: Any,
) -> None:
    """Write values from ``flat_array`` back into ``params_dict`` in place.

    This is the inverse of flattening: ``flat_array[k]`` is stored, as a
    ``float``, at the location given by ``paths[k] == (name, mf_index, param_name)``.

    Only the first ``len(paths)`` entries of ``flat_array`` are read. A shorter
    array raises ``IndexError`` from the indexing.

    Args:
        flat_array: Indexable sequence of values, one per path.
        paths: Locations of the parameters, in the same order as ``flat_array``.
        params_dict: Structure with a ``"membership"`` entry laid out as
            ``{name: [{param_name: value, ...}, ...]}``; mutated in place.
    """
    for position, (mf_name, mf_index, param_name) in enumerate(paths):
        new_value = float(flat_array[position])
        params_dict["membership"][mf_name][mf_index][param_name] = new_value

94 

95 

# Public helpers for creating, walking, reading, writing and (un)flattening the
# nested parameter structure used by the optimizers, where "membership" is laid
# out as {name: [{param_name: value, ...}, ...]}.
__all__ = [
    "zeros_like_structure",
    "iterate_membership_params",
    "iterate_membership_params_with_state",
    "update_membership_param",
    "get_membership_param",
    "flatten_membership_params",
    "unflatten_membership_params",
]