Optimisers

Training strategies for highFIS estimators.

Trainers decouple the optimisation loop from the sklearn estimator layer. GradientTrainer implements standard single-phase mini-batch gradient descent. DGTrainer implements the three-phase data-guided (DG) training protocol required by DG-TSK and DG-ALETSK. FSRETrainer implements the three-phase FSRE training protocol required by FSRE-ADATSK.

Optimisers
  • GradientTrainer — standard mini-batch gradient descent.
  • DGTrainer — data-guided three-phase protocol for DG-TSK and DG-ALETSK.
  • FSRETrainer — three-phase FSRE protocol for FSRE-ADATSK.

BaseTrainer

Bases: ABC

Abstract base class for training strategies.

A trainer encapsulates the optimisation loop and is decoupled from the sklearn estimator. Concrete subclasses implement the full training protocol — e.g., single-phase gradient descent, three-phase DG training, or hybrid LSE / gradient procedures.
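
The decoupling described above can be sketched with plain-Python stand-ins. The names ToyModel, Trainer, and MeanTrainer are hypothetical and are not part of highFIS; only the shape of the fit contract (model first, keyword-only validation arguments, a history dict as the return value) is taken from the abstract method on this page.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any


class ToyModel:
    """Hypothetical stand-in for a model: predicts one constant value."""

    def __init__(self) -> None:
        self.bias = 0.0

    def predict(self, x: list[float]) -> list[float]:
        return [self.bias for _ in x]


class Trainer(ABC):
    """Mirrors the documented contract: fit(model, x, y, *, ...) -> history."""

    @abstractmethod
    def fit(
        self,
        model: ToyModel,
        x: list[float],
        y: list[float],
        *,
        x_val: list[float] | None = None,
        y_val: list[float] | None = None,
        metrics: list[str] | None = None,
    ) -> dict[str, Any]: ...


class MeanTrainer(Trainer):
    """Toy strategy: set the bias to the mean target in a single step."""

    def fit(self, model, x, y, *, x_val=None, y_val=None, metrics=None):
        model.bias = sum(y) / len(y)
        mse = sum((p - t) ** 2 for p, t in zip(model.predict(x), y)) / len(y)
        return {"train_loss": [mse]}


model = ToyModel()
history = MeanTrainer().fit(model, [0.0, 1.0, 2.0], [1.0, 2.0, 3.0])
```

Because the model only exposes its parameters and a forward pass, swapping MeanTrainer for another strategy would not require touching ToyModel, which is the point of keeping the optimisation loop in a separate class.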

fit abstractmethod

Train model on (x, y) and return a history dictionary.

Parameters:

Name Type Description Default
model BaseTSK

The BaseTSK model to train.

required
x Tensor

Training input tensor of shape (n_samples, n_features).

required
y Tensor

Training target tensor.

required
x_val Tensor | None

Optional validation input tensor.

None
y_val Tensor | None

Optional validation target tensor.

None
metrics list[str] | None

Optional list of metric names to evaluate.

None

Returns:

Type Description
dict[str, Any]

A dictionary with training history. Keys depend on the concrete implementation. GradientTrainer returns the history dict from BaseTSK.fit(). DGTrainer returns a dict with keys "dg", "threshold", and "finetune".

Source code in highfis/optim/_base.py
@abstractmethod
def fit(
    self,
    model: BaseTSK,
    x: Tensor,
    y: Tensor,
    *,
    x_val: Tensor | None = None,
    y_val: Tensor | None = None,
    metrics: list[str] | None = None,
) -> dict[str, Any]:
    """Train *model* on *(x, y)* and return a history dictionary.

    Args:
        model: The BaseTSK model to train.
        x: Training input tensor of shape ``(n_samples, n_features)``.
        y: Training target tensor.
        x_val: Optional validation input tensor.
        y_val: Optional validation target tensor.
        metrics: Optional list of metric names to evaluate.

    Returns:
        A dictionary with training history.  Keys depend on the concrete
        implementation. GradientTrainer returns the history
        dict from BaseTSK.fit(). DGTrainer
        returns a dict with keys ``"dg"``, ``"threshold"``, and
        ``"finetune"``.
    """

DGTrainer

Bases: BaseTrainer

Three-phase trainer for DG-TSK and DG-ALETSK estimators.

Implements the data-guided (DG) training procedure described in Xue et al. (2023), which consists of three sequential phases:

  1. DG phase — Train gate parameters (λ, θ) and zero-order consequent parameters. For DG-TSK the antecedent MF parameters are frozen (P-FRB, paper §III-A). For DG-ALETSK the antecedents are updated together with the gates (CoCo-FRB, paper eq. 22).
  2. Threshold search — Grid-search over (zeta_lambda, zeta_theta) pairs to find the pruning thresholds that maximise held-out performance, optionally refitting first-order consequents via LSE.
  3. Fine-tune phase — Convert the model to first-order consequents and retrain with antecedent MFs and feature gates (λ) frozen.

Each phase is delegated to the corresponding method on the model: fit_dg_phase(), search_thresholds(), and fit_finetune().
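
Phase 2 can be pictured as an exhaustive search over threshold pairs. The sketch below is pure Python and does not call highFIS: score is a hypothetical stand-in for whatever held-out metric search_thresholds() evaluates, and the grid is the default documented for zeta_lambda and zeta_theta.

```python
from itertools import product

DEFAULT_ZETA = [0.0, 0.25, 0.5, 0.75, 1.0]  # default grid from the parameter table


def score(zeta_lambda: float, zeta_theta: float) -> float:
    """Hypothetical validation score, peaked at (0.5, 0.25)."""
    return -((zeta_lambda - 0.5) ** 2 + (zeta_theta - 0.25) ** 2)


def grid_search(lambdas, thetas):
    """Return the best (zeta_lambda, zeta_theta) pair and the number of pairs tried."""
    candidates = list(product(lambdas, thetas))
    best = max(candidates, key=lambda pair: score(*pair))
    return best, len(candidates)


best_pair, n_evaluated = grid_search(DEFAULT_ZETA, DEFAULT_ZETA)
```

With the default grid this evaluates 25 pairs, so every extra candidate value adds a full row or column of evaluations.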

Example
from highfis import DGTSKClassifier
from highfis.optim import DGTrainer

trainer = DGTrainer(dg_epochs=20, finetune_epochs=100, use_lse=True)
clf = DGTSKClassifier(trainer=trainer)
clf.fit(X_train, y_train, x_val=X_val, y_val=y_val)

Initialise a DG trainer.

Parameters:

Name Type Description Default
dg_epochs int

Epochs for the DG phase (phase 1).

10
dg_learning_rate float

SGD learning rate for the DG phase.

0.01
dg_batch_size int | None

Mini-batch size for the DG phase.

512
dg_shuffle bool

Reshuffle samples each epoch in the DG phase.

True
dg_patience int | None

Early-stopping patience for the DG phase.

20
dg_weight_decay float

L2 weight-decay for the DG phase.

1e-08
dg_ur_weight float

Uncertainty regularisation weight for the DG phase.

0.0
dg_ur_target float | None

Uncertainty regularisation target for the DG phase.

None
zeta_lambda list[float] | None

Grid of λ-threshold candidates for pruning. If None, uses [0.0, 0.25, 0.5, 0.75, 1.0].

None
zeta_theta list[float] | None

Grid of θ-threshold candidates. Same default.

None
use_lse bool

Refit first-order consequents via LSE during threshold search. Recommended (default True).

True
finetune_epochs int

Epochs for the fine-tune phase (phase 3).

200
finetune_learning_rate float

Learning rate for fine-tuning. The optimiser is selected by optimizer_type.

0.01
finetune_batch_size int | None

Mini-batch size for fine-tuning.

512
finetune_shuffle bool

Reshuffle samples each epoch during fine-tuning.

True
finetune_patience int | None

Early-stopping patience for fine-tuning.

20
finetune_restore_best bool

Restore best validation weights after fine-tuning.

True
finetune_weight_decay float

L2 weight-decay for fine-tuning.

1e-08
finetune_ur_weight float

Uncertainty regularisation weight for fine-tuning.

0.0
finetune_ur_target float | None

Uncertainty regularisation target for fine-tuning.

None
verbose bool | int

Verbosity level forwarded to all three phases.

False
loss Callable[..., Any] | None

Custom loss function f(output, target) -> scalar. None uses the model's built-in criterion.

None
optimizer_type str

Optimiser type used in all three phases. "sgd" (paper default) or "adamw" (AdamW with weight decay).

'sgd'
structural_pruning bool

If True (paper default), hard-prune the model architecture after threshold search. If False, only soft-prune (zero out gate parameters).

True
finetune_freeze_antecedents bool

If True (paper default), freeze MF parameters and feature-selection gates during fine-tuning. If False, only feature gates are frozen (MFs train freely).

True
Source code in highfis/optim/_dg.py
def __init__(
    self,
    *,
    # ── Phase 1: DG ──────────────────────────────────────────────────
    dg_epochs: int = 10,
    dg_learning_rate: float = 1e-2,
    dg_batch_size: int | None = 512,
    dg_shuffle: bool = True,
    dg_patience: int | None = 20,
    dg_weight_decay: float = 1e-8,
    dg_ur_weight: float = 0.0,
    dg_ur_target: float | None = None,
    # ── Phase 2: Threshold search ─────────────────────────────────────
    zeta_lambda: list[float] | None = None,
    zeta_theta: list[float] | None = None,
    use_lse: bool = True,
    # ── Phase 3: Fine-tune ────────────────────────────────────────────
    finetune_epochs: int = 200,
    finetune_learning_rate: float = 1e-2,
    finetune_batch_size: int | None = 512,
    finetune_shuffle: bool = True,
    finetune_patience: int | None = 20,
    finetune_restore_best: bool = True,
    finetune_weight_decay: float = 1e-8,
    finetune_ur_weight: float = 0.0,
    finetune_ur_target: float | None = None,
    # ── Shared ────────────────────────────────────────────────────────
    verbose: bool | int = False,
    loss: Callable[..., Any] | None = None,
    # ── Paper-conformance options ─────────────────────────────────────
    optimizer_type: str = "sgd",
    structural_pruning: bool = True,
    finetune_freeze_antecedents: bool = True,
) -> None:
    """Initialise a DG trainer.

    Args:
        dg_epochs: Epochs for the DG phase (phase 1).
        dg_learning_rate: SGD learning rate for the DG phase.
        dg_batch_size: Mini-batch size for the DG phase.
        dg_shuffle: Reshuffle samples each epoch in the DG phase.
        dg_patience: Early-stopping patience for the DG phase.
        dg_weight_decay: L2 weight-decay for the DG phase.
        dg_ur_weight: Uncertainty regularisation weight for the DG phase.
        dg_ur_target: Uncertainty regularisation target for the DG phase.
        zeta_lambda: Grid of λ-threshold candidates for pruning.  If
            ``None``, uses ``[0.0, 0.25, 0.5, 0.75, 1.0]``.
        zeta_theta: Grid of θ-threshold candidates.  Same default.
        use_lse: Refit first-order consequents via LSE during threshold
            search.  Recommended (default ``True``).
        finetune_epochs: Epochs for the fine-tune phase (phase 3).
        finetune_learning_rate: Adam learning rate for fine-tuning.
        finetune_batch_size: Mini-batch size for fine-tuning.
        finetune_shuffle: Reshuffle samples each epoch during fine-tuning.
        finetune_patience: Early-stopping patience for fine-tuning.
        finetune_restore_best: Restore best validation weights after
            fine-tuning.
        finetune_weight_decay: L2 weight-decay for fine-tuning.
        finetune_ur_weight: Uncertainty regularisation weight for fine-tuning.
        finetune_ur_target: Uncertainty regularisation target for fine-tuning.
        verbose: Verbosity level forwarded to all three phases.
        loss: Custom loss function ``f(output, target) -> scalar``.
            ``None`` uses the model's built-in criterion.
        optimizer_type: Optimiser type used in all three phases.  ``"sgd"``
            (paper default) or ``"adamw"`` (AdamW with weight decay).
        structural_pruning: If ``True`` (paper default), hard-prune the
            model architecture after threshold search.  If ``False``, only
            soft-prune (zero out gate parameters).
        finetune_freeze_antecedents: If ``True`` (paper default), freeze MF
            parameters and feature-selection gates during fine-tuning.  If
            ``False``, only feature gates are frozen (MFs train freely).
    """
    self.dg_epochs = dg_epochs
    self.dg_learning_rate = dg_learning_rate
    self.dg_batch_size = dg_batch_size
    self.dg_shuffle = dg_shuffle
    self.dg_patience = dg_patience
    self.dg_weight_decay = dg_weight_decay
    self.dg_ur_weight = dg_ur_weight
    self.dg_ur_target = dg_ur_target
    self.zeta_lambda = zeta_lambda
    self.zeta_theta = zeta_theta
    self.use_lse = use_lse
    self.finetune_epochs = finetune_epochs
    self.finetune_learning_rate = finetune_learning_rate
    self.finetune_batch_size = finetune_batch_size
    self.finetune_shuffle = finetune_shuffle
    self.finetune_patience = finetune_patience
    self.finetune_restore_best = finetune_restore_best
    self.finetune_weight_decay = finetune_weight_decay
    self.finetune_ur_weight = finetune_ur_weight
    self.finetune_ur_target = finetune_ur_target
    self.verbose = verbose
    self.loss = loss
    self.optimizer_type = optimizer_type
    self.structural_pruning = structural_pruning
    self.finetune_freeze_antecedents = finetune_freeze_antecedents

fit

Execute the three-phase DG training procedure.

Parameters:

Name Type Description Default
model BaseTSK

A DG-TSK or DG-ALETSK model instance, such as DGTSKClassifierModel, DGTSKRegressorModel, DGALETSKClassifierModel, or DGALETSKRegressorModel.

required
x Tensor

Training inputs of shape (N, D).

required
y Tensor

Training targets.

required
x_val Tensor | None

Validation inputs (used for threshold-search scoring). When None, training data is used for threshold selection.

None
y_val Tensor | None

Validation targets.

None
metrics list[str] | None

Optional list of metric names to evaluate.

None

Returns:

Type Description
dict[str, Any]

Dictionary with keys:

  • "dg": history dict from phase 1 (DG phase).
  • "threshold": result dict from search_thresholds().
  • "finetune": history dict from phase 3 (fine-tune phase).
Source code in highfis/optim/_dg.py
def fit(
    self,
    model: BaseTSK,
    x: Tensor,
    y: Tensor,
    *,
    x_val: Tensor | None = None,
    y_val: Tensor | None = None,
    metrics: list[str] | None = None,
) -> dict[str, Any]:
    """Execute the three-phase DG training procedure.

    Args:
        model: A DG-TSK or DG-ALETSK model instance, such as
            DGTSKClassifierModel, DGTSKRegressorModel,
            DGALETSKClassifierModel, or DGALETSKRegressorModel.
        x: Training inputs of shape ``(N, D)``.
        y: Training targets.
        x_val: Validation inputs (used for threshold-search scoring).
            When ``None``, training data is used for threshold selection.
        y_val: Validation targets.
        metrics: Optional list of metric names to evaluate.

    Returns:
        Dictionary with keys:

        - ``"dg"`` — history dict from phase 1 (DG phase).
        - ``"threshold"`` — result dict from search_thresholds().
        - ``"finetune"`` — history dict from phase 3 (fine-tune phase).
    """
    zeta_lambda = self.zeta_lambda if self.zeta_lambda is not None else _DEFAULT_ZETA
    zeta_theta = self.zeta_theta if self.zeta_theta is not None else _DEFAULT_ZETA

    if not isinstance(model, DGModelProtocol):
        raise TypeError("model must implement DGModelProtocol")

    # ── Phase 1: DG training ──────────────────────────────────────────
    dg_history = self._run_dg_phase(model, x, y, x_val, y_val, metrics)

    # ── Phase 2: Threshold search + pruning ───────────────────────────
    x_eval = x_val if x_val is not None else x
    y_eval = y_val if y_val is not None else y

    threshold_result: dict[str, Any] = model.search_thresholds(
        x,
        y,
        zeta_lambda=zeta_lambda,
        zeta_theta=zeta_theta,
        x_val=x_eval,
        y_val=y_eval,
        use_lse=bool(self.use_lse),
        inplace=True,
        verbose=bool(self.verbose),
        structural=bool(self.structural_pruning),
    )

    # Slice x to surviving features when structural pruning was applied.
    sf = threshold_result.get("surviving_feature_indices")
    if self.structural_pruning and sf is not None and len(sf) < x.shape[1]:
        x_ft: Tensor = x[:, sf]
        x_val_ft: Tensor | None = x_val[:, sf] if x_val is not None else None
    else:
        x_ft, x_val_ft = x, x_val

    # ── Phase 3: Fine-tune ────────────────────────────────────────────
    finetune_history = self._run_finetune_phase(model, x_ft, y, x_val_ft, y_val, metrics)

    return {
        "dg": dg_history,
        "threshold": threshold_result,
        "finetune": finetune_history,
    }
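
The slicing step between phases 2 and 3 can be illustrated without tensors. This is a sketch on nested Python lists rather than the Tensor indexing used in the listing; slice_columns is a hypothetical helper, and its condition mirrors the len(sf) < x.shape[1] check in the source.

```python
def slice_columns(rows, surviving, structural_pruning=True):
    """Keep only surviving feature columns, mirroring the check in fit()."""
    n_features = len(rows[0])
    if structural_pruning and surviving is not None and len(surviving) < n_features:
        return [[row[j] for j in surviving] for row in rows]
    return rows  # nothing pruned: fine-tune on the original inputs


x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
x_ft = slice_columns(x, [0, 2])  # features 0 and 2 survive
x_same = slice_columns(x, [0, 2], structural_pruning=False)
```

When structural pruning is disabled the inputs pass through unchanged, which matches the else branch of the listing.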

FSRETrainer

Bases: BaseTrainer

Three-phase trainer for FSRE-ADATSK classifier and regressor models.

Implements the FSRE algorithm, which consists of three sequential phases:

  1. FS phase — Train with CoCo-FRB and feature gates M(λ_d) active (paper eq. 21). After training, features with gate activation M(λ_d) > τ_λ are retained.
  2. RE phase — Expand to En-FRB and train with rule gates M(θ_r) active (paper eq. 22). After training, rules with gate activation M(θ_r) > τ_θ are retained. For classifiers, at least n_classes rules are kept.
  3. Fine-tune phase — Train the pruned model without gates (paper eq. 5).

Thresholds are computed directly from scalar zeta coefficients.
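
The retention rules in phases 1 and 2 can be sketched in plain Python. The threshold is passed in as a number because the formula that maps a zeta coefficient to τ is not reproduced on this page; select_above is a hypothetical helper, and its fallback mirrors the behaviour in the fit listing (keep the most-activated feature if none passes, keep at least n_classes rules for classifiers).

```python
def select_above(gates, tau, min_keep=1):
    """Indices whose gate activation exceeds tau, with a top-k fallback."""
    kept = [i for i, g in enumerate(gates) if g > tau]
    if len(kept) < min_keep:
        # Too few survivors: keep the min_keep most-activated entries instead.
        ranked = sorted(range(len(gates)), key=lambda i: gates[i], reverse=True)
        kept = sorted(ranked[:min_keep])
    return kept


feature_gates = [0.9, 0.1, 0.6, 0.2]
features = select_above(feature_gates, tau=0.5)   # gates above 0.5
fallback = select_above(feature_gates, tau=0.95)  # none pass, keep the argmax

rule_gates = [0.2, 0.8, 0.1, 0.4]
rules = select_above(rule_gates, tau=0.5, min_keep=3)  # e.g. a 3-class classifier
```

The comparison is strict, so an entry whose activation equals the threshold is dropped.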

Example
from highfis import FSREADATSKClassifier
from highfis.optim import FSRETrainer

trainer = FSRETrainer(
    fs_epochs=10,
    re_epochs=10,
    finetune_epochs=100,
    zeta_lambda=0.5,
    zeta_theta=0.3,
)
clf = FSREADATSKClassifier(trainer=trainer)
clf.fit(X_train, y_train)

Initialise an FSRE trainer.

Parameters:

Name Type Description Default
fs_epochs int

Epochs for the FS phase (phase 1). Default 10.

10
fs_learning_rate float

Adam learning rate for the FS phase.

0.01
fs_batch_size int | None

Mini-batch size for the FS phase.

512
fs_shuffle bool

Reshuffle samples each epoch in the FS phase.

True
fs_patience int | None

Early-stopping patience for the FS phase.

20
fs_weight_decay float

L2 weight-decay for the FS phase.

1e-08
fs_ur_weight float

Uncertainty regularisation weight for the FS phase.

0.0
fs_ur_target float | None

Uncertainty regularisation target for the FS phase.

None
re_epochs int

Epochs for the RE phase (phase 2). Default 10.

10
re_learning_rate float

Adam learning rate for the RE phase.

0.01
re_batch_size int | None

Mini-batch size for the RE phase.

512
re_shuffle bool

Reshuffle samples each epoch in the RE phase.

True
re_patience int | None

Early-stopping patience for the RE phase.

20
re_weight_decay float

L2 weight-decay for the RE phase.

1e-08
re_ur_weight float

Uncertainty regularisation weight for the RE phase.

0.0
re_ur_target float | None

Uncertainty regularisation target for the RE phase.

None
finetune_epochs int

Epochs for the fine-tune phase (phase 3). Default 100.

100
finetune_learning_rate float

Adam learning rate for fine-tuning.

0.01
finetune_batch_size int | None

Mini-batch size for fine-tuning.

512
finetune_shuffle bool

Reshuffle samples each epoch during fine-tuning.

True
finetune_patience int | None

Early-stopping patience for fine-tuning.

20
finetune_restore_best bool

Restore best validation weights after fine-tuning.

True
finetune_weight_decay float

L2 weight-decay for fine-tuning.

1e-08
finetune_ur_weight float

Uncertainty regularisation weight for fine-tuning.

0.0
finetune_ur_target float | None

Uncertainty regularisation target for fine-tuning.

None
zeta_lambda float

Coefficient to compute the feature-selection threshold τ_λ (paper eq. 28). Larger values retain more features; 0.5 is recommended for low-dimensional data, 0.4 for high-dimensional data.

_DEFAULT_ZETA_LAMBDA
zeta_theta float

Coefficient to compute the rule-extraction threshold τ_θ (paper eq. 29). 0.3 is recommended for low-dimensional data, 0.5 for high-dimensional data.

_DEFAULT_ZETA_THETA
structural_pruning bool

If True (default), hard-prune the model architecture after each threshold step. If False, only the gate values are modified but the model structure is unchanged.

True
verbose bool | int

Verbosity level forwarded to all three phases.

False
loss Callable[..., Any] | None

Custom loss function f(output, target) -> scalar. None uses the model's built-in criterion.

None
Source code in highfis/optim/_fsre.py
def __init__(
    self,
    *,
    # ── Phase 1: FS ──────────────────────────────────────────────────
    fs_epochs: int = 10,
    fs_learning_rate: float = 1e-2,
    fs_batch_size: int | None = 512,
    fs_shuffle: bool = True,
    fs_patience: int | None = 20,
    fs_weight_decay: float = 1e-8,
    fs_ur_weight: float = 0.0,
    fs_ur_target: float | None = None,
    # ── Phase 2: RE ──────────────────────────────────────────────────
    re_epochs: int = 10,
    re_learning_rate: float = 1e-2,
    re_batch_size: int | None = 512,
    re_shuffle: bool = True,
    re_patience: int | None = 20,
    re_weight_decay: float = 1e-8,
    re_ur_weight: float = 0.0,
    re_ur_target: float | None = None,
    # ── Phase 3: Fine-tune ────────────────────────────────────────────
    finetune_epochs: int = 100,
    finetune_learning_rate: float = 1e-2,
    finetune_batch_size: int | None = 512,
    finetune_shuffle: bool = True,
    finetune_patience: int | None = 20,
    finetune_restore_best: bool = True,
    finetune_weight_decay: float = 1e-8,
    finetune_ur_weight: float = 0.0,
    finetune_ur_target: float | None = None,
    # ── Thresholds ────────────────────────────────────────────────────
    zeta_lambda: float = _DEFAULT_ZETA_LAMBDA,
    zeta_theta: float = _DEFAULT_ZETA_THETA,
    # ── Pruning ───────────────────────────────────────────────────────
    structural_pruning: bool = True,
    # ── Shared ────────────────────────────────────────────────────────
    verbose: bool | int = False,
    loss: Callable[..., Any] | None = None,
) -> None:
    """Initialise an FSRE trainer.

    Args:
        fs_epochs: Epochs for the FS phase (phase 1).  Default ``10``.
        fs_learning_rate: Adam learning rate for the FS phase.
        fs_batch_size: Mini-batch size for the FS phase.
        fs_shuffle: Reshuffle samples each epoch in the FS phase.
        fs_patience: Early-stopping patience for the FS phase.
        fs_weight_decay: L2 weight-decay for the FS phase.
        fs_ur_weight: Uncertainty regularisation weight for the FS phase.
        fs_ur_target: Uncertainty regularisation target for the FS phase.
        re_epochs: Epochs for the RE phase (phase 2).  Default ``10``.
        re_learning_rate: Adam learning rate for the RE phase.
        re_batch_size: Mini-batch size for the RE phase.
        re_shuffle: Reshuffle samples each epoch in the RE phase.
        re_patience: Early-stopping patience for the RE phase.
        re_weight_decay: L2 weight-decay for the RE phase.
        re_ur_weight: Uncertainty regularisation weight for the RE phase.
        re_ur_target: Uncertainty regularisation target for the RE phase.
        finetune_epochs: Epochs for the fine-tune phase (phase 3).
            Default ``100``.
        finetune_learning_rate: Adam learning rate for fine-tuning.
        finetune_batch_size: Mini-batch size for fine-tuning.
        finetune_shuffle: Reshuffle samples each epoch during fine-tuning.
        finetune_patience: Early-stopping patience for fine-tuning.
        finetune_restore_best: Restore best validation weights after
            fine-tuning.
        finetune_weight_decay: L2 weight-decay for fine-tuning.
        finetune_ur_weight: Uncertainty regularisation weight for
            fine-tuning.
        finetune_ur_target: Uncertainty regularisation target for
            fine-tuning.
        zeta_lambda: Coefficient to compute the feature-selection
            threshold τ_λ (paper eq. 28).  Larger values retain more
            features; ``0.5`` is recommended for low-dimensional data,
            ``0.4`` for high-dimensional data.
        zeta_theta: Coefficient to compute the rule-extraction threshold
            τ_θ (paper eq. 29).  ``0.3`` is recommended for
            low-dimensional data, ``0.5`` for high-dimensional data.
        structural_pruning: If ``True`` (default), hard-prune the model
            architecture after each threshold step.  If ``False``, only
            the gate values are modified but the model structure is
            unchanged.
        verbose: Verbosity level forwarded to all three phases.
        loss: Custom loss function ``f(output, target) -> scalar``.
            ``None`` uses the model's built-in criterion.
    """
    self.fs_epochs = fs_epochs
    self.fs_learning_rate = fs_learning_rate
    self.fs_batch_size = fs_batch_size
    self.fs_shuffle = fs_shuffle
    self.fs_patience = fs_patience
    self.fs_weight_decay = fs_weight_decay
    self.fs_ur_weight = fs_ur_weight
    self.fs_ur_target = fs_ur_target

    self.re_epochs = re_epochs
    self.re_learning_rate = re_learning_rate
    self.re_batch_size = re_batch_size
    self.re_shuffle = re_shuffle
    self.re_patience = re_patience
    self.re_weight_decay = re_weight_decay
    self.re_ur_weight = re_ur_weight
    self.re_ur_target = re_ur_target

    self.finetune_epochs = finetune_epochs
    self.finetune_learning_rate = finetune_learning_rate
    self.finetune_batch_size = finetune_batch_size
    self.finetune_shuffle = finetune_shuffle
    self.finetune_patience = finetune_patience
    self.finetune_restore_best = finetune_restore_best
    self.finetune_weight_decay = finetune_weight_decay
    self.finetune_ur_weight = finetune_ur_weight
    self.finetune_ur_target = finetune_ur_target

    self.zeta_lambda = zeta_lambda
    self.zeta_theta = zeta_theta
    self.structural_pruning = structural_pruning
    self.verbose = verbose
    self.loss = loss

fit

Execute the three-phase FSRE training procedure (Algorithm 1).

Parameters:

Name Type Description Default
model BaseTSK

An FSRE-ADATSK model instance, such as FSREADATSKClassifierModel or FSREADATSKRegressorModel.

required
x Tensor

Training inputs of shape (N, D).

required
y Tensor

Training targets.

required
x_val Tensor | None

Validation inputs (used for early stopping). When None, no external validation is performed.

None
y_val Tensor | None

Validation targets.

None
metrics list[str] | None

Optional list of metric names to evaluate.

None

Returns:

Type Description
dict[str, Any]

Dictionary with keys:

  • "fs": history dict from phase 1 (FS phase).
  • "re": history dict from phase 2 (RE phase).
  • "finetune": history dict from phase 3 (fine-tune phase).
  • "surviving_feature_indices": list of retained feature indices relative to the input x columns.
  • "surviving_rule_indices": list of retained rule indices relative to the En-FRB rule count after phase 2.
  • "tau_lambda": applied feature-selection threshold.
  • "tau_theta": applied rule-extraction threshold.
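
The returned indices can be mapped back to column names. The result dict below is a hand-written example using the documented keys, not output from a real run, and feature_names is a hypothetical list supplied by the caller.

```python
feature_names = ["age", "height", "weight", "income"]  # hypothetical columns

# Hand-written stand-in for the dict returned by FSRETrainer.fit().
result = {
    "surviving_feature_indices": [0, 3],
    "surviving_rule_indices": [1, 2, 5],
    "tau_lambda": 0.42,
    "tau_theta": 0.31,
}

kept_features = [feature_names[i] for i in result["surviving_feature_indices"]]
dropped_features = [name for name in feature_names if name not in kept_features]
n_rules = len(result["surviving_rule_indices"])
```

Feature indices refer to the columns of the original x, so the lookup works even though later phases train on the sliced inputs.
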
Source code in highfis/optim/_fsre.py
def fit(
    self,
    model: BaseTSK,
    x: Tensor,
    y: Tensor,
    *,
    x_val: Tensor | None = None,
    y_val: Tensor | None = None,
    metrics: list[str] | None = None,
) -> dict[str, Any]:
    """Execute the three-phase FSRE training procedure (Algorithm 1).

    Args:
        model: An FSRE-ADATSK model instance, such as
            FSREADATSKClassifierModel or FSREADATSKRegressorModel.
        x: Training inputs of shape ``(N, D)``.
        y: Training targets.
        x_val: Validation inputs (used for early stopping).
            When ``None``, no external validation is performed.
        y_val: Validation targets.
        metrics: Optional list of metric names to evaluate.

    Returns:
        Dictionary with keys:

        - ``"fs"`` — history dict from phase 1 (FS phase).
        - ``"re"`` — history dict from phase 2 (RE phase).
        - ``"finetune"`` — history dict from phase 3 (fine-tune phase).
        - ``"surviving_feature_indices"`` — list of retained feature indices
            relative to the input ``x`` columns.
        - ``"surviving_rule_indices"`` — list of retained rule indices
            relative to the En-FRB rule count after phase 2.
        - ``"tau_lambda"`` — applied feature-selection threshold.
        - ``"tau_theta"`` — applied rule-extraction threshold.
    """
    from ._gradient import GradientTrainer

    if not isinstance(model, FSREModelProtocol):
        raise TypeError("model must implement FSREModelProtocol")

    # ── Phase 1: Feature Selection ────────────────────────────────────
    # Set consequent to FS mode (only feature gates M(λ_d) active)
    model.set_consequent_mode("fs")
    fs_trainer = GradientTrainer(
        epochs=int(self.fs_epochs),
        learning_rate=float(self.fs_learning_rate),
        loss=self.loss,
        batch_size=self.fs_batch_size,
        shuffle=bool(self.fs_shuffle),
        ur_weight=float(self.fs_ur_weight),
        ur_target=self.fs_ur_target,
        verbose=self.verbose,
        patience=self.fs_patience,
        weight_decay=float(self.fs_weight_decay),
    )
    fs_history: dict[str, Any] = fs_trainer.fit(model, x, y, x_val=x_val, y_val=y_val, metrics=metrics)

    # ── Feature threshold & selection (paper eq. 28) ──────────────────
    feat_gates: Tensor = model.get_feature_gate_values()
    tau_lambda: float = _threshold_from_zeta(feat_gates, float(self.zeta_lambda))
    sf: list[int] = [i for i, v in enumerate(feat_gates.tolist()) if v > tau_lambda]
    if not sf:
        # Edge case: keep the single most-activated feature.
        sf = [int(feat_gates.argmax().item())]

    if self.structural_pruning:
        model.prune_to_features(sf)
        x_fs: Tensor = x[:, sf]
        x_val_fs: Tensor | None = x_val[:, sf] if x_val is not None else None
    else:
        x_fs, x_val_fs = x, x_val

    # ── Phase 2: Rule Extraction ──────────────────────────────────────
    # Expand to En-FRB (rebuilds rule_layer and consequent_layer with mode="re")
    model.expand_to_en_frb()
    re_trainer = GradientTrainer(
        epochs=int(self.re_epochs),
        learning_rate=float(self.re_learning_rate),
        loss=self.loss,
        batch_size=self.re_batch_size,
        shuffle=bool(self.re_shuffle),
        ur_weight=float(self.re_ur_weight),
        ur_target=self.re_ur_target,
        verbose=self.verbose,
        patience=self.re_patience,
        weight_decay=float(self.re_weight_decay),
    )
    re_history: dict[str, Any] = re_trainer.fit(model, x_fs, y, x_val=x_val_fs, y_val=y_val, metrics=metrics)

    # ── Rule threshold & selection (paper eq. 29) ─────────────────────
    rule_gates: Tensor = model.get_rule_gate_values()
    tau_theta: float = _threshold_from_zeta(rule_gates, float(self.zeta_theta))
    sr: list[int] = [r for r, v in enumerate(rule_gates.tolist()) if v > tau_theta]

    # Enforce lower bound: for classifiers ≥ n_classes rules (paper §III-C).
    from ..models._common import BaseTSKClassifierModel

    min_rules: int = 1
    if isinstance(model, BaseTSKClassifierModel):
        min_rules = model.n_classes
    if len(sr) < min_rules:
        top_indices: list[int] = torch.topk(rule_gates, min_rules).indices.tolist()
        sr = sorted(top_indices)

    if self.structural_pruning:
        model.prune_to_rules(sr)

    # Set consequent to finetune mode (no gates — plain TSK consequent)
    model.set_consequent_mode("finetune")
    ft_trainer = GradientTrainer(
        epochs=int(self.finetune_epochs),
        learning_rate=float(self.finetune_learning_rate),
        loss=self.loss,
        batch_size=self.finetune_batch_size,
        shuffle=bool(self.finetune_shuffle),
        ur_weight=float(self.finetune_ur_weight),
        ur_target=self.finetune_ur_target,
        verbose=self.verbose,
        patience=self.finetune_patience,
        restore_best=bool(self.finetune_restore_best),
        weight_decay=float(self.finetune_weight_decay),
    )
    finetune_history: dict[str, Any] = ft_trainer.fit(model, x_fs, y, x_val=x_val_fs, y_val=y_val, metrics=metrics)

    return {
        "fs": fs_history,
        "re": re_history,
        "finetune": finetune_history,
        "surviving_feature_indices": sf,
        "surviving_rule_indices": sr,
        "tau_lambda": tau_lambda,
        "tau_theta": tau_theta,
    }

GradientTrainer

Bases: BaseTrainer

Single-phase mini-batch gradient descent trainer.

This is the default trainer used by all standard highFIS estimators.

Example

from highfis.optim import GradientTrainer

trainer = GradientTrainer(epochs=200, learning_rate=1e-3)
history = trainer.fit(model, x_train, y_train)
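
The batching options can be illustrated with index lists. This is a sketch of the usual mini-batch pattern, not highFIS's internal loop; iter_batches is a hypothetical helper, and treating batch_size=None as one full-dataset batch follows the parameter description on this page.

```python
import random


def iter_batches(n_samples, batch_size=None, shuffle=True, seed=0):
    """Yield lists of sample indices that together cover one epoch."""
    if n_samples == 0:
        return
    indices = list(range(n_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)  # new sample order for this epoch
    size = n_samples if batch_size is None else batch_size
    for start in range(0, n_samples, size):
        yield indices[start:start + size]


batches = list(iter_batches(10, batch_size=4))
full = list(iter_batches(10, batch_size=None, shuffle=False))
```

The last batch is smaller when batch_size does not divide the sample count, and every sample still appears exactly once per epoch.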

Initialise a gradient trainer.

Parameters:

Name Type Description Default
epochs int

Maximum number of full passes over the training data.

200
learning_rate float

Initial learning rate for the Adam optimiser.

0.01
batch_size int | None

Mini-batch size. None uses the full dataset.

512
shuffle bool

Reshuffle samples before each epoch.

True
patience int | None

Early-stopping patience. None disables early stopping.

20
restore_best bool

Restore the best validation weights after training.

True
weight_decay float

L2 weight-decay for consequent parameters.

1e-08
ur_weight float

Uncertainty regularisation weight.

0.0
ur_target float | None

Uncertainty regularisation target firing-level.

None
verbose bool | int

Verbosity level:

  • False / 0: silent.
  • True / 1: progress bar (tqdm).
  • 2: log ~10 points during training.
  • 3: log every epoch.

False
loss Callable[..., Any] | None

Custom loss function f(output, target) -> scalar. None uses the model's built-in criterion.

None
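
The interaction between patience and restore_best can be sketched on a list of validation losses. This is the common early-stopping pattern and assumes that patience counts consecutive epochs without improvement; how highFIS counts them internally is not shown on this page. run_early_stopping is a hypothetical helper.

```python
def run_early_stopping(val_losses, patience=None):
    """Return (epochs_run, best_epoch) for a sequence of validation losses."""
    best_loss = float("inf")
    best_epoch = -1
    stale = 0  # consecutive epochs without improvement
    epochs_run = 0
    for epoch, loss in enumerate(val_losses):
        epochs_run = epoch + 1
        if loss < best_loss:
            best_loss, best_epoch, stale = loss, epoch, 0
        else:
            stale += 1
            if patience is not None and stale >= patience:
                break  # with restore_best, weights from best_epoch would be reloaded
    return epochs_run, best_epoch


losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.9, 0.6]
stopped = run_early_stopping(losses, patience=2)
unlimited = run_early_stopping(losses, patience=None)
```

In this toy run a patience of 2 stops after five epochs and misses the later improvement at the final epoch, which is the trade-off a small patience value makes.
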
Source code in highfis/optim/_gradient.py
def __init__(
    self,
    *,
    epochs: int = 200,
    learning_rate: float = 1e-2,
    batch_size: int | None = 512,
    shuffle: bool = True,
    patience: int | None = 20,
    restore_best: bool = True,
    weight_decay: float = 1e-8,
    ur_weight: float = 0.0,
    ur_target: float | None = None,
    verbose: bool | int = False,
    loss: Callable[..., Any] | None = None,
) -> None:
    """Initialise a gradient trainer.

    Args:
        epochs: Maximum number of full passes over the training data.
        learning_rate: Initial learning rate for the Adam optimiser.
        batch_size: Mini-batch size. ``None`` uses the full dataset.
        shuffle: Reshuffle samples before each epoch.
        patience: Early-stopping patience.  ``None`` disables early
            stopping.
        restore_best: Restore the best validation weights after training.
        weight_decay: L2 weight-decay for consequent parameters.
        ur_weight: Uncertainty regularisation weight.
        ur_target: Uncertainty regularisation target firing-level.
        verbose: Verbosity level.
            - ``False`` / ``0``: silent.
            - ``True``  / ``1``: progress bar (tqdm).
            - ``2``: log ~10 points during training.
            - ``3``: log every epoch.
        loss: Custom loss function ``f(output, target) -> scalar``.
            ``None`` uses the model's built-in criterion.
    """
    self.epochs = epochs
    self.learning_rate = learning_rate
    self.batch_size = batch_size
    self.shuffle = shuffle
    self.patience = patience
    self.restore_best = restore_best
    self.weight_decay = weight_decay
    self.ur_weight = ur_weight
    self.ur_target = ur_target
    self.verbose = verbose
    self.loss = loss
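Because loss accepts any callable of the form f(output, target) -> scalar, a custom criterion can be written as a plain function over tensors. The sketch below defines a log-cosh loss as one possible choice; the function name log_cosh_loss is illustrative, and it assumes output and target have matching (or broadcastable) shapes.

```python
import math

import torch
from torch import Tensor


def log_cosh_loss(output: Tensor, target: Tensor) -> Tensor:
    """Mean log-cosh error: roughly quadratic for small residuals, linear for large ones."""
    abs_diff = (output - target).abs()
    # Stable identity: log(cosh(d)) = |d| + log(1 + exp(-2|d|)) - log(2)
    per_sample = abs_diff + torch.nn.functional.softplus(-2.0 * abs_diff) - math.log(2.0)
    return per_sample.mean()  # reduce to a 0-dim tensor, i.e. a scalar


prediction = torch.tensor([0.5, -1.0, 2.0])
target = torch.tensor([0.0, -1.0, 4.0])
print(log_cosh_loss(prediction, target))

# Passing it to the trainer, following the constructor signature shown above:
# trainer = GradientTrainer(epochs=200, loss=log_cosh_loss)
```

Returning a 0-dim tensor that is still attached to the autograd graph is what lets the training loop call backward on the result.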

fit

Train model for up to epochs epochs and return the history dict. Training ends earlier if early stopping is triggered on the validation set.
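The returned history can be inspected like any dictionary of per-epoch lists. The sketch below uses a hand-written stand-in with made-up values rather than a real training run; the key names are taken from the fit source listing on this page, and the helper summarise_history is hypothetical.

```python
from __future__ import annotations

from typing import Any

# Stand-in for the dict returned by fit(); the numbers are invented for illustration.
history: dict[str, Any] = {
    "train_loss": [0.90, 0.60, 0.45],
    "train_ur_loss": [0.00, 0.00, 0.00],
    "train_total_loss": [0.90, 0.60, 0.45],
    "val_loss": [0.95, 0.70, 0.72],
    "lr": [1e-2, 1e-2, 5e-3],
    "stopped_epoch": 3,
}


def summarise_history(history: dict[str, Any]) -> dict[str, Any]:
    """Collect a few headline numbers from a training history."""
    val = history.get("val_loss") or []
    # 1-based epoch with the lowest validation loss, if validation data was used.
    best_epoch = min(range(len(val)), key=val.__getitem__) + 1 if val else None
    return {
        "stopped_epoch": history["stopped_epoch"],
        "final_train_loss": history["train_loss"][-1],
        "best_val_epoch": best_epoch,
        "final_lr": history["lr"][-1],
    }


print(summarise_history(history))
```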

Source code in highfis/optim/_gradient.py
def fit(
    self,
    model: BaseTSK,
    x: Tensor,
    y: Tensor,
    *,
    x_val: Tensor | None = None,
    y_val: Tensor | None = None,
    metrics: list[str] | None = None,
    optimizer: torch.optim.Optimizer | None = None,
    scheduler: torch.optim.lr_scheduler.LRScheduler | Any = None,
) -> dict[str, Any]:
    """Train *model* for :attr:`epochs` epochs and return the history dict."""
    has_val = self._validate_fit_inputs(model, x, y, x_val, y_val, self.ur_weight, self.ur_target)
    train_criterion = self.loss or model.default_criterion()
    train_optimizer = self._build_optimizer(model, optimizer, self.learning_rate, self.weight_decay)

    train_loader = self._build_train_loader(x, y)
    metrics_list, maximize = self._resolve_metrics(metrics, model.task_type)

    history = self._init_history(has_val, metrics_list)
    best_metric = float("-inf")
    epochs_no_improve = 0
    best_state: dict[str, Any] | None = None
    verbose_level = _resolve_verbose(self.verbose)

    model.train()
    epoch_iterator, pbar = self._get_epoch_iterator(verbose_level)

    for epoch in epoch_iterator:
        history["stopped_epoch"] = epoch + 1
        epoch_main_loss, epoch_ur_loss, epoch_total_loss = self._run_minibatch_epoch(
            model,
            train_loader,
            train_criterion,
            train_optimizer,
            self.ur_weight,
            self.ur_target,
        )
        history["train"].append(epoch_total_loss)
        history["ur"].append(epoch_ur_loss)
        history["train_loss"].append(epoch_main_loss)
        history["train_ur_loss"].append(epoch_ur_loss)
        history["train_total_loss"].append(epoch_total_loss)

        # Evaluate metrics on train set
        if metrics_list:
            self._evaluate_epoch_metrics(model, x, y, metrics_list, history, "train")

        if has_val and x_val is not None and y_val is not None:
            best_metric, epochs_no_improve, best_state, should_stop = self._handle_validation_epoch(
                model,
                x_val,
                y_val,
                metrics_list,
                maximize,
                train_criterion,
                history,
                best_metric,
                epochs_no_improve,
                best_state,
                verbose_level,
                epoch,
                epoch_total_loss,
                pbar,
            )
            if should_stop:
                break
        else:
            self._log_epoch_no_val(model, epoch, self.epochs, epoch_total_loss, verbose_level, pbar)

        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                val_loss_val = history["val_loss"][-1] if (has_val and history["val_loss"]) else epoch_main_loss
                scheduler.step(val_loss_val)
            else:
                scheduler.step()

        current_lr = train_optimizer.param_groups[0]["lr"]
        history["lr"].append(current_lr)

    self._finalize_training(model, best_state, pbar)

    return history
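The scheduler branch at the end of the epoch loop exists because ReduceLROnPlateau needs the monitored value passed to step, whereas other PyTorch schedulers are stepped without arguments. The following standalone sketch reproduces that dispatch outside highFIS; the helper step_scheduler and the constant loss sequence are illustrative assumptions.

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR


def step_scheduler(scheduler, monitored_loss: float) -> None:
    """Advance a scheduler once per epoch, passing a metric only when it is required."""
    if isinstance(scheduler, ReduceLROnPlateau):
        scheduler.step(monitored_loss)
    else:
        scheduler.step()


param = torch.nn.Parameter(torch.zeros(1))

opt_step = torch.optim.Adam([param], lr=1e-2)
step_lr = StepLR(opt_step, step_size=2, gamma=0.5)  # halves the rate every 2 epochs

opt_plateau = torch.optim.Adam([param], lr=1e-2)
plateau = ReduceLROnPlateau(opt_plateau, mode="min", factor=0.5, patience=0)

lrs_step, lrs_plateau = [], []
for epoch_loss in [1.0, 1.0, 1.0, 1.0]:  # a loss that never improves
    # Step the optimisers first, as a training loop would, then the schedulers.
    opt_step.step()
    opt_plateau.step()
    step_scheduler(step_lr, epoch_loss)
    step_scheduler(plateau, epoch_loss)
    lrs_step.append(opt_step.param_groups[0]["lr"])
    lrs_plateau.append(opt_plateau.param_groups[0]["lr"])

print(lrs_step)
print(lrs_plateau)
```

In fit, the monitored value is the latest validation loss when validation data is available and the epoch's main training loss otherwise, so a plateau scheduler still works without a validation set.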