Persistence
Versioned checkpoint helpers for estimator persistence.
This module provides functions to save and load highFIS estimator checkpoints
using PyTorch serialization. Checkpoints store estimator constructor
parameters, the fitted model state dict, and sklearn-compatible fit metadata
(n_features_in_, feature_names_in_, classes_, etc.).
The checkpoint schema is versioned. CHECKPOINT_FORMAT identifies the
payload type and CHECKPOINT_FORMAT_VERSION is an integer string incremented
only when the checkpoint schema itself changes — independent of the package
version. validate_checkpoint_payload enforces both so that incompatible
checkpoints are rejected before any state is restored.
Checkpoint schema keys:
- format: must equal CHECKPOINT_FORMAT.
- format_version: must equal CHECKPOINT_FORMAT_VERSION.
- estimator_class: class name of the estimator that created the checkpoint.
- estimator_params: constructor kwargs used to recreate the estimator.
- model_init: metadata needed to rebuild the model architecture.
- model_state_dict: serialized model weights.
- fitted_attrs: sklearn fit metadata.
- history (optional): per-epoch training history.
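The schema checks described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the constant values and the `check_payload` helper are hypothetical stand-ins, and only the key names are taken from the schema listed above.

```python
from typing import Any

# Hypothetical values for illustration; the real constants are defined
# in highfis.persistence and may differ.
CHECKPOINT_FORMAT = "example-estimator-checkpoint"
CHECKPOINT_FORMAT_VERSION = "1"

# Keys taken from the schema list above; "history" is optional and not required.
REQUIRED_KEYS = (
    "format",
    "format_version",
    "estimator_class",
    "estimator_params",
    "model_init",
    "model_state_dict",
    "fitted_attrs",
)


def check_payload(checkpoint: dict[str, Any], expected_estimator_class: str) -> None:
    """Raise ValueError if the payload does not match the expected schema."""
    missing = [key for key in REQUIRED_KEYS if key not in checkpoint]
    if missing:
        raise ValueError(f"checkpoint is missing required keys: {missing}")
    if checkpoint["format"] != CHECKPOINT_FORMAT:
        raise ValueError(f"unexpected format: {checkpoint['format']!r}")
    if checkpoint["format_version"] != CHECKPOINT_FORMAT_VERSION:
        raise ValueError(
            f"unsupported format_version: {checkpoint['format_version']!r}"
        )
    if checkpoint["estimator_class"] != expected_estimator_class:
        raise ValueError(
            f"checkpoint belongs to {checkpoint['estimator_class']!r}, "
            f"expected {expected_estimator_class!r}"
        )


# A payload with every required key and placeholder contents.
payload = {
    "format": CHECKPOINT_FORMAT,
    "format_version": CHECKPOINT_FORMAT_VERSION,
    "estimator_class": "HTSKClassifier",
    "estimator_params": {},
    "model_init": {},
    "model_state_dict": {},
    "fitted_attrs": {"n_features_in_": 4},
}
check_payload(payload, "HTSKClassifier")  # passes silently
```

Running all checks before touching any state means a mismatched checkpoint fails early with a specific message rather than partway through reconstruction.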
Examples:
>>> from highfis.persistence import load_checkpoint, validate_checkpoint_payload
>>> ckpt = load_checkpoint("artifacts/clf.pt")
>>> validate_checkpoint_payload(ckpt, expected_estimator_class="HTSKClassifier")
deserialize_input_mfs
Reconstruct input_mfs from a serialized config dict.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| config | dict[str, list[dict[str, Any]]] | A dict as returned by :func: | required |
Returns:
| Type | Description |
|---|---|
| dict[str, list[Any]] | A mapping of feature name to a list of membership function objects, for passing to :class: |
Raises:
| Type | Description |
|---|---|
| ValueError | If the config contains an unrecognised MF type name. |
Source code in highfis/persistence.py
load_checkpoint
Load a checkpoint dictionary from disk into CPU memory.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str \| Path | Source file path of a checkpoint previously saved with :func: | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | The deserialized checkpoint dictionary. |
Raises:
| Type | Description |
|---|---|
| ValueError | If the loaded payload is not a dictionary. |
save_checkpoint
Save a checkpoint dictionary to disk using PyTorch serialization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str \| Path | Target file path. Parent directories are created automatically if they do not exist. | required |
| checkpoint | dict[str, Any] | Dictionary payload containing estimator state. | required |
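The save and load behaviour documented here can be approximated with `torch.save` and `torch.load`. The sketch below is an assumption about how such helpers might be written, not the library's source: `save_sketch` and `load_sketch` are hypothetical names, and the use of `weights_only=True` is inferred from the note on `serialize_input_mfs` rather than confirmed.

```python
import tempfile
from pathlib import Path
from typing import Any, Union

import torch


def save_sketch(path: Union[str, Path], checkpoint: dict[str, Any]) -> None:
    """Write the payload, creating parent directories first."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    torch.save(checkpoint, target)


def load_sketch(path: Union[str, Path]) -> dict[str, Any]:
    """Read the payload onto the CPU and reject anything that is not a dict."""
    # weights_only=True is an assumption; it restricts unpickling to
    # tensors and primitive containers.
    payload = torch.load(Path(path), map_location="cpu", weights_only=True)
    if not isinstance(payload, dict):
        raise ValueError("checkpoint payload must be a dictionary")
    return payload


with tempfile.TemporaryDirectory() as tmp:
    # The "artifacts" directory does not exist yet; save_sketch creates it.
    path = Path(tmp) / "artifacts" / "clf.pt"
    save_sketch(
        path,
        {"format_version": "1", "model_state_dict": {"w": torch.zeros(2, 3)}},
    )
    restored = load_sketch(path)
```

Loading with `map_location="cpu"` lets a checkpoint written on a GPU machine be opened on one without a GPU.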
serialize_input_mfs
Serialize an nn.ModuleDict of membership functions to a plain dict.
Converts each membership function to a {"type": classname, "params": {...}}
entry so the checkpoint contains only primitive Python types and tensors,
making it compatible with torch.load(..., weights_only=True).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_mfs | Any | The nn.ModuleDict of membership functions to serialize. | required |
Returns:
| Type | Description |
|---|---|
| dict[str, list[dict[str, Any]]] | A JSON-serializable dict mapping feature name to a list of MF configs. |
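The `{"type": classname, "params": {...}}` layout and its inverse can be illustrated without PyTorch. In this sketch everything except that layout is hypothetical: `GaussianMF`, its `get_params` method, `MF_REGISTRY`, and the `serialize`/`deserialize` helpers are stand-ins, and a plain dict replaces the `nn.ModuleDict` the real function accepts.

```python
from typing import Any


class GaussianMF:
    """Hypothetical membership function standing in for the library's MF classes."""

    def __init__(self, mean: float, sigma: float) -> None:
        self.mean = mean
        self.sigma = sigma

    def get_params(self) -> dict[str, float]:
        return {"mean": self.mean, "sigma": self.sigma}


# Hypothetical lookup table from type name to class, used when rebuilding.
MF_REGISTRY: dict[str, type] = {"GaussianMF": GaussianMF}


def serialize(input_mfs: dict[str, list[Any]]) -> dict[str, list[dict[str, Any]]]:
    """Turn each MF into a {"type": ..., "params": ...} entry of plain values."""
    return {
        name: [{"type": type(mf).__name__, "params": mf.get_params()} for mf in mfs]
        for name, mfs in input_mfs.items()
    }


def deserialize(config: dict[str, list[dict[str, Any]]]) -> dict[str, list[Any]]:
    """Rebuild MF objects, rejecting type names that are not registered."""
    rebuilt: dict[str, list[Any]] = {}
    for name, entries in config.items():
        rebuilt[name] = []
        for entry in entries:
            cls = MF_REGISTRY.get(entry["type"])
            if cls is None:
                raise ValueError(f"unrecognised MF type: {entry['type']!r}")
            rebuilt[name].append(cls(**entry["params"]))
    return rebuilt


config = serialize({"x1": [GaussianMF(0.0, 1.0), GaussianMF(2.0, 0.5)]})
restored = deserialize(config)
```

Storing a class name plus constructor parameters, rather than the pickled object, is what keeps the payload free of arbitrary Python objects.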
validate_checkpoint_payload
Validate a loaded checkpoint payload before estimator reconstruction.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| checkpoint | dict[str, Any] | Checkpoint dictionary returned by :func: | required |
| expected_estimator_class | str | Name of the estimator class that is expected to own this checkpoint. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If format, format_version, or estimator_class do not match expected values, or if required keys are missing. |