
Base

Base TSK model that factors out the common antecedent-defuzzification pipeline.

This module defines BaseTSK, the abstract foundation for all TSK fuzzy models in highFIS. It factors out the shared antecedent pipeline, defuzzifier, and training loop so concrete subclasses can focus on task-specific consequent layers and loss criteria.

The forward pipeline executes four sequential steps:

  1. highfis.layers.MembershipLayer — evaluates membership functions for each input feature.
  2. highfis.layers.RuleLayer — computes rule firing strengths via a configurable rule base and T-norm.
  3. Defuzzifier — normalizes firing strengths to probability-like weights (default: highfis.defuzzifiers.SoftmaxLogDefuzzifier).
  4. ConsequentLayer — produces the final output from the inputs and the normalized rule weights.
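
For orientation, here is a minimal sketch of how the four steps compose, assuming model is an already-constructed BaseTSK subclass and x is a (batch, n_inputs) tensor. The exact call that the internal _forward_train makes to the consequent layer is an assumption for illustration, not the library's verbatim code.

mu = model.membership_layer(x)               # 1. membership degrees per feature
w = model.rule_layer(mu)                     # 2. rule firing strengths
norm_w = model.defuzzifier(w)                # 3. normalized, probability-like weights
output = model.consequent_layer(x, norm_w)   # 4. final output (signature assumed)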

Concrete subclasses must implement:

  • BaseTSK._build_consequent_layer — return the task-specific consequent module.
  • BaseTSK._default_criterion — return the default loss function.

Optional overridable hooks:

  • BaseTSK._compute_loss — customize target preparation or loss composition.
  • BaseTSK._evaluate_validation — customize the validation metric used for early stopping.
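
A minimal, hypothetical subclass illustrating this contract is sketched below. The zero-order consequent module and its (inputs, normalized weights) calling convention are assumptions made for the example, not classes shipped by highFIS.

import torch
import torch.nn as nn
from torch import Tensor

from highfis.base import BaseTSK


class _ZeroOrderConsequent(nn.Module):
    """Illustrative zero-order consequent: one learnable constant per rule."""

    def __init__(self, n_rules: int) -> None:
        super().__init__()
        self.rule_outputs = nn.Parameter(torch.zeros(n_rules))

    def forward(self, x: Tensor, norm_w: Tensor) -> Tensor:
        # Weighted sum of per-rule constants: (batch, n_rules) -> (batch,)
        return norm_w @ self.rule_outputs


class TinyRegressorTSK(BaseTSK):
    """Hypothetical regression model built on BaseTSK."""

    def _build_consequent_layer(self) -> nn.Module:
        # self.n_rules is already set because the rule layer is built first.
        return _ZeroOrderConsequent(self.n_rules)

    def _default_criterion(self) -> nn.Module:
        return nn.MSELoss()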

BaseTSK

Bases: nn.Module

Abstract base for TSK fuzzy models.

Subclasses must implement _build_consequent_layer and _default_criterion. Optionally override _compute_loss and _evaluate_validation for task-specific logic.

Initialize the TSK pipeline layers.

Parameters:

  • input_mfs (Mapping[str, Sequence[MembershipFunction]], required): Mapping from feature names to sequences of MembershipFunction objects. Must not be empty.
  • rule_base (str, default 'cartesian'): Rule-base construction strategy. Supported values: "cartesian" (all MF combinations), "coco" (same-index compact), "en" (enhanced FRB), or "custom" (explicit rules via rules).
  • t_norm (str, default 'gmean'): Built-in T-norm name. Ignored when t_norm_fn is provided. Common values: "prod", "gmean", "min", "dombi", "yager".
  • t_norm_fn (TNormFn | None, default None): Optional custom T-norm callable. When provided, t_norm is internally set to "prod" and the rule layer applies this function instead.
  • rules (Sequence[Sequence[int]] | None, default None): Explicit rule index sequences. Required when rule_base is "custom".
  • defuzzifier (nn.Module | None, default None): Normalization module applied to raw rule firing strengths. Defaults to SoftmaxLogDefuzzifier.
  • consequent_batch_norm (bool, default False): If True, insert a torch.nn.BatchNorm1d layer on the inputs before the consequent computation.

Raises:

  • ValueError: If input_mfs is empty.

Source code in highfis/base.py
def __init__(
    self,
    input_mfs: Mapping[str, Sequence[MembershipFunction]],
    *,
    rule_base: str = "cartesian",
    t_norm: str = "gmean",
    t_norm_fn: TNormFn | None = None,
    rules: Sequence[Sequence[int]] | None = None,
    defuzzifier: nn.Module | None = None,
    consequent_batch_norm: bool = False,
) -> None:
    """Initialize the TSK pipeline layers.

    Args:
        input_mfs: Mapping from feature names to sequences of
            :class:`~highfis.memberships.MembershipFunction` objects.
            Must not be empty.
        rule_base: Rule-base construction strategy.  Supported values:
            ``"cartesian"`` (all MF combinations), ``"coco"``
            (same-index compact), ``"en"`` (enhanced FRB), or
            ``"custom"`` (explicit rules via *rules*).
        t_norm: Built-in T-norm name.  Ignored when *t_norm_fn* is
            provided.  Common values: ``"prod"``, ``"gmean"``,
            ``"min"``, ``"dombi"``, ``"yager"``.
        t_norm_fn: Optional custom T-norm callable.  When provided,
            *t_norm* is internally set to ``"prod"`` and the rule
            layer applies this function instead.
        rules: Explicit rule index sequences.  Required when
            *rule_base* is ``"custom"``.
        defuzzifier: Normalization module applied to raw rule firing
            strengths.  Defaults to
            :class:`~highfis.defuzzifiers.SoftmaxLogDefuzzifier`.
        consequent_batch_norm: If ``True``, insert a
            :class:`~torch.nn.BatchNorm1d` layer on the inputs before
            the consequent computation.

    Raises:
        ValueError: If *input_mfs* is empty.
    """
    super().__init__()
    if not input_mfs:
        raise ValueError("input_mfs must not be empty")

    self.input_mfs = input_mfs
    self.input_names = list(input_mfs.keys())
    self.n_inputs = len(self.input_names)
    mf_per_input = [len(input_mfs[name]) for name in self.input_names]

    self.membership_layer = MembershipLayer(input_mfs)
    self.rule_layer = RuleLayer(
        self.input_names,
        mf_per_input,
        rules=rules,
        rule_base=rule_base,
        t_norm=t_norm if t_norm_fn is None else "prod",
        t_norm_fn=t_norm_fn,
    )
    self.n_rules = self.rule_layer.n_rules
    self.defuzzifier = defuzzifier or SoftmaxLogDefuzzifier()
    self.consequent_batch_norm = bool(consequent_batch_norm)
    self.consequent_bn = nn.BatchNorm1d(self.n_inputs) if self.consequent_batch_norm else None
    self.consequent_layer = self._build_consequent_layer()
    self.logger = logging.getLogger(f"{self.__class__.__module__}.{self.__class__.__name__}")
    if not self.logger.handlers:
        stream_handler = logging.StreamHandler(sys.stdout)
        stream_handler.setFormatter(logging.Formatter("%(message)s"))
        self.logger.addHandler(stream_handler)
        self.logger.setLevel(logging.INFO)
        self.logger.propagate = False
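
A construction sketch using the hypothetical TinyRegressorTSK from above. GaussianMF(center, width) is a stand-in name for whichever MembershipFunction subclasses highFIS actually provides, and the rule comments only illustrate the strategies described under rule_base.

# Two features with two membership functions each (GaussianMF is assumed;
# substitute a real highfis.memberships.MembershipFunction subclass).
input_mfs = {
    "temperature": [GaussianMF(0.0, 1.0), GaussianMF(5.0, 1.0)],
    "pressure": [GaussianMF(0.0, 1.0), GaussianMF(5.0, 1.0)],
}

# rule_base="cartesian" would enumerate all 2 x 2 = 4 MF combinations, while
# "custom" takes explicit per-feature MF indices:
model = TinyRegressorTSK(
    input_mfs,
    rule_base="custom",
    rules=[[0, 0], [1, 1]],  # one MF index per input feature, per rule
    t_norm="prod",
)
print(model.n_rules)  # 2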

fit

Train the model with optional early stopping.

When x_val and y_val are provided, the model evaluates a task-specific metric (via _evaluate_validation) after every epoch and applies early stopping once the metric has not improved for patience consecutive epochs. When restore_best=True (the default), the weights from the best validation epoch are restored after training.

Parameters:

  • x (Tensor, required): Training features of shape (N, n_inputs).
  • y (Tensor, required): Training targets of shape (N,).
  • epochs (int, default 200): Maximum number of training epochs.
  • learning_rate (float, default 0.001): Learning rate for the default AdamW optimizer.
  • criterion (Callable[[Tensor, Tensor], Tensor] | None, default None): Optional loss function. Defaults to _default_criterion.
  • optimizer (torch.optim.Optimizer | None, default None): Optional pre-built optimizer. When None, AdamW is constructed with separate parameter groups for antecedent (no weight decay) and consequent (weight_decay) layers.
  • batch_size (int | None, default None): Mini-batch size. None uses the full dataset.
  • shuffle (bool, default True): If True, reshuffle sample indices each epoch.
  • ur_weight (float, default 0.0): Non-negative weight for the uniform rule regularization term. 0.0 disables it.
  • ur_target (float | None, default None): Target uniform activation for UR. Must be in (0, 1] when provided. None defaults to 1 / n_rules.
  • verbose (bool | int, default False): Verbosity level. 0 = quiet, 1 = progress bar, 2 = per-epoch summary logging, 3 = per-epoch detailed logging. True is accepted as an alias for 2.
  • x_val (Tensor | None, default None): Optional validation features of shape (M, n_inputs).
  • y_val (Tensor | None, default None): Optional validation targets of shape (M,).
  • patience (int | None, default 20): Number of consecutive epochs without improvement before early stopping. Set to None to disable early stopping. Only active when x_val and y_val are given.
  • restore_best (bool, default True): If True, restore the model weights from the best validation epoch when early stopping is used.
  • weight_decay (float, default 1e-08): L2 weight decay applied to consequent parameters by the default AdamW optimizer.

Returns:

  • dict[str, Any]: A dictionary with keys "train", "ur", and "val" containing per-epoch loss lists.

Raises:

  • ValueError: If shapes of x, y, x_val, or y_val are incompatible, or if ur_weight < 0 or ur_target is outside (0, 1].

Source code in highfis/base.py
def fit(
    self,
    x: Tensor,
    y: Tensor,
    epochs: int = 200,
    learning_rate: float = 1e-3,
    criterion: Callable[[Tensor, Tensor], Tensor] | None = None,
    optimizer: torch.optim.Optimizer | None = None,
    batch_size: int | None = None,
    shuffle: bool = True,
    ur_weight: float = 0.0,
    ur_target: float | None = None,
    verbose: bool | int = False,
    x_val: Tensor | None = None,
    y_val: Tensor | None = None,
    patience: int | None = 20,
    restore_best: bool = True,
    weight_decay: float = 1e-8,
) -> dict[str, Any]:
    """Train the model with optional early stopping.

    When *x_val* and *y_val* are provided the model evaluates a
    task-specific metric (via :meth:`_evaluate_validation`) after every
    epoch and applies early stopping when the metric has not improved for
    *patience* consecutive epochs.
    By default the best model weights from validation are restored when
    ``restore_best=True``.

    Args:
        x: Training features of shape ``(N, n_inputs)``.
        y: Training targets of shape ``(N,)``.
        epochs: Maximum number of training epochs.
        learning_rate: Learning rate for the default AdamW optimizer.
        criterion: Optional loss function.  Defaults to
            :meth:`_default_criterion`.
        optimizer: Optional pre-built optimizer.  When ``None``, AdamW
            is constructed with separate parameter groups for antecedent
            (no weight decay) and consequent (*weight_decay*) layers.
        batch_size: Mini-batch size.  ``None`` uses the full dataset.
        shuffle: If ``True``, reshuffle sample indices each epoch.
        ur_weight: Non-negative weight for the uniform rule
            regularization term.  ``0.0`` disables it.
        ur_target: Target uniform activation for UR.  Must be in
            ``(0, 1]`` when provided.  ``None`` defaults to
            ``1 / n_rules``.
        verbose: Verbosity level. ``0`` = quiet, ``1`` = progress bar,
            ``2`` = per-epoch summary logging, ``3`` = per-epoch detailed
            logging. ``True`` is accepted as an alias for ``2``.
        x_val: Optional validation features of shape
            ``(M, n_inputs)``.
        y_val: Optional validation targets of shape ``(M,)``.
        patience: Number of consecutive epochs without improvement
            before early stopping.  Set to ``None`` to disable early
            stopping.  Only active when *x_val* and *y_val* are given.
        restore_best: If ``True`` (default), restore the model weights
            from the best validation epoch when early stopping is used.
        weight_decay: L2 weight decay applied to consequent parameters
            by the default AdamW optimizer.

    Returns:
        A dictionary with keys ``"train"``, ``"ur"``, and ``"val"``
        containing per-epoch loss lists.

    Raises:
        ValueError: If shapes of *x*, *y*, *x_val*, or *y_val* are
            incompatible, or if *ur_weight* < 0 or *ur_target* is
            outside ``(0, 1]``.
    """
    if x.ndim != 2 or x.shape[1] != self.n_inputs:
        raise ValueError(f"expected x shape (batch, {self.n_inputs}), got {tuple(x.shape)}")
    if y.ndim != 1:
        raise ValueError("expected y shape (batch,)")
    if ur_weight < 0.0:
        raise ValueError("ur_weight must be >= 0")
    if ur_target is not None and not (0.0 < ur_target <= 1.0):
        raise ValueError("ur_target must be in (0, 1] when provided")

    has_val = x_val is not None and y_val is not None
    if has_val:
        if x_val is None or y_val is None:  # pragma: no cover
            raise ValueError("x_val and y_val must both be provided")
        if x_val.ndim != 2 or x_val.shape[1] != self.n_inputs:
            raise ValueError(f"expected x_val shape (batch, {self.n_inputs}), got {tuple(x_val.shape)}")
        if y_val.ndim != 1:
            raise ValueError("expected y_val shape (batch,)")

    train_criterion = criterion or self._default_criterion()
    if optimizer is not None:
        train_optimizer = optimizer
    else:
        ante_params = list(self.membership_layer.parameters())
        rule_params = list(self.rule_layer.parameters())
        cons_params = list(self.consequent_layer.parameters())
        if self.consequent_bn is not None:
            cons_params.extend(self.consequent_bn.parameters())
        train_optimizer = torch.optim.AdamW(
            [
                {"params": ante_params, "weight_decay": 0.0},
                {"params": rule_params, "weight_decay": 0.0},
                {"params": cons_params, "weight_decay": weight_decay},
            ],
            lr=learning_rate,
        )

    history: dict[str, Any] = {"train": [], "ur": [], "val": []}
    best_metric = float("-inf")
    epochs_no_improve = 0
    best_state: dict[str, Any] | None = None
    verbose_level = self._resolve_verbose(verbose)

    self.train()
    pbar = None
    if verbose_level == 1:
        pbar = trange(epochs, desc="Training", leave=False)
        epoch_iterator = pbar
    else:
        epoch_iterator = range(epochs)

    for epoch in epoch_iterator:
        batch_losses: list[float] = []
        batch_ur_losses: list[float] = []
        for batch_idx in _iter_minibatch_indices(x.shape[0], batch_size=batch_size, shuffle=shuffle):
            x_b = x.index_select(0, batch_idx.to(device=x.device))
            y_b = y.index_select(0, batch_idx.to(device=y.device))

            train_optimizer.zero_grad(set_to_none=True)
            output, norm_w = self._forward_train(x_b)
            main_loss = self._compute_loss(train_criterion, output, y_b)

            ur_loss = _uniform_regularization_loss(norm_w, target=ur_target)
            loss = main_loss + (float(ur_weight) * ur_loss)
            loss.backward()
            train_optimizer.step()

            batch_losses.append(float(loss.detach().item()))
            batch_ur_losses.append(float(ur_loss.detach().item()))

        epoch_train_loss = float(sum(batch_losses) / max(len(batch_losses), 1))
        history["train"].append(epoch_train_loss)
        history["ur"].append(float(sum(batch_ur_losses) / max(len(batch_ur_losses), 1)))

        # --- validation & early stopping ---
        if has_val and x_val is not None and y_val is not None:
            self.eval()
            val_info = self._evaluate_validation(train_criterion, x_val, y_val)
            history["val"].append(val_info.get("val_loss", 0.0))
            # Store any extra keys (e.g. val_acc) in history
            for k, v in val_info.items():
                if k not in ("val_loss", "metric"):
                    history.setdefault(k, []).append(v)
            self.train()

            metric = val_info["metric"]
            if metric > best_metric:
                best_metric = metric
                epochs_no_improve = 0
                best_state = copy.deepcopy(self.state_dict())
            else:
                epochs_no_improve += 1

            if verbose_level == 1:
                if pbar is None:  # pragma: no cover
                    raise RuntimeError("progress bar unavailable for verbose level 1")
                postfix = [
                    f"train={epoch_train_loss:.4f}",
                    f"val={val_info.get('val_loss', 0.0):.4f}",
                ]
                pbar.set_postfix_str(" ".join(postfix))
            if verbose_level >= 2 and (
                verbose_level == 3 or ((epoch + 1) % max(epochs // 10, 1) == 0 or epoch == 0)
            ):
                log_parts = [
                    f"epoch={epoch + 1}/{epochs}",
                    f"train_loss={epoch_train_loss:.6f}",
                ]
                for k, v in val_info.items():
                    if k != "metric":
                        log_parts.append(f"{k}={v:.6f}" if isinstance(v, float) else f"{k}={v}")
                self._log(" ".join(log_parts), verbose=verbose_level)

            if patience is not None and epochs_no_improve >= patience:
                if verbose_level >= 2:
                    self._log(
                        "early stopping at epoch %s (patience=%s)",
                        epoch + 1,
                        patience,
                        verbose=verbose_level,
                    )
                break
        else:
            if verbose_level == 1:
                if pbar is None:  # pragma: no cover
                    raise RuntimeError("progress bar unavailable for verbose level 1")
                pbar.set_postfix_str(f"loss={epoch_train_loss:.4f}")
            if verbose_level >= 2 and (
                verbose_level == 3 or ((epoch + 1) % max(epochs // 10, 1) == 0 or epoch == 0)
            ):
                self._log(
                    "epoch=%s/%s loss=%.6f",
                    epoch + 1,
                    epochs,
                    epoch_train_loss,
                    verbose=verbose_level,
                )

    if pbar is not None:
        pbar.close()

    if restore_best and best_state is not None:
        self.load_state_dict(best_state)

    history["stopped_epoch"] = epoch + 1  # type: ignore[possibly-undefined]

    return history
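
A training sketch showing validation-driven early stopping and the returned history. The data here is synthetic, and the shapes simply follow the parameter descriptions above.

# Synthetic data for illustration: (N, 2) features and (N,) targets.
x_train, y_train = torch.randn(256, 2), torch.randn(256)
x_valid, y_valid = torch.randn(64, 2), torch.randn(64)

history = model.fit(
    x_train,
    y_train,
    epochs=300,
    learning_rate=1e-3,
    batch_size=64,
    x_val=x_valid,
    y_val=y_valid,
    patience=20,        # stop after 20 epochs without validation improvement
    restore_best=True,  # roll back to the best validation epoch
    verbose=2,          # per-epoch summary logging
)

print(history["stopped_epoch"])                   # epoch at which training stopped
print(history["train"][-1], history["val"][-1])   # last recorded losses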

forward

Full forward pass through the TSK pipeline.

Source code in highfis/base.py
def forward(self, x: Tensor) -> Tensor:
    """Full forward pass through the TSK pipeline."""
    output, _ = self._forward_train(x)
    return output

forward_antecedents

Compute normalized rule strengths from model antecedents.

Source code in highfis/base.py
def forward_antecedents(self, x: Tensor) -> Tensor:
    """Compute normalized rule strengths from model antecedents."""
    mu = self.membership_layer(x)
    w = self.rule_layer(mu)
    return cast(Tensor, self.defuzzifier(w))
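
A small usage sketch for inspecting rule activations. With the default SoftmaxLogDefuzzifier the returned weights are probability-like, so argmax picks the dominant rule per sample.

with torch.no_grad():
    norm_w = model.forward_antecedents(x_valid)  # shape (M, n_rules)

dominant_rule = norm_w.argmax(dim=1)   # most active rule for each sample
mean_activation = norm_w.mean(dim=0)   # average weight per rule across the batch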