verskyt.visualizations

Visualization tools for Tversky Neural Networks.

This module provides functions for visualizing and interpreting the prototypes and features learned by TNNs. They are designed to make the abstract notions of “prototypes” and “features” concrete and inspectable for research analysis.

Module: plotting

plot_prototype_space(prototypes: Tensor, prototype_labels: List[str], features: Tensor | None = None, feature_labels: List[str] | None = None, reduction_method: str = 'pca', title: str = 'Learned Prototype Space', ax: Axes | None = None) → Axes

Visualizes high-dimensional prototypes and features in a 2D space.

This function uses dimensionality reduction to project high-dimensional prototype and feature vectors into 2D for visualization, so researchers can see which prototypes and features the model has placed close together.
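How verskyt computes the projection internally is not documented here. As a minimal sketch, assuming prototypes and features are projected jointly so that both share one coordinate system, a 2D PCA can be written with a plain SVD in NumPy. The helper name `pca_project_2d` is hypothetical. It also returns the fraction of variance the two plotted axes retain, which indicates how faithful the 2D picture is. For PyTorch parameters, pass `tensor.detach().cpu().numpy()`.

```python
import numpy as np


def pca_project_2d(prototypes, features=None):
    """Project prototypes (and optional features) into one shared 2D PCA space.

    Hypothetical helper, not verskyt's implementation. Returns
    (proto_xy, feat_xy, retained): 2D coordinates for each set, plus the
    fraction of total variance captured by the two plotted axes.
    """
    protos = np.asarray(prototypes, dtype=float)
    feats = None if features is None else np.asarray(features, dtype=float)
    # Fit one projection on everything so both sets share a coordinate system.
    stacked = protos if feats is None else np.vstack([protos, feats])
    centered = stacked - stacked.mean(axis=0, keepdims=True)
    # Rows of vt are the principal directions, ordered by singular value.
    _, s, vt = np.linalg.svd(centered, full_matrices=False)
    coords = centered @ vt[:2].T
    # Pad with zeros if the data has fewer than two principal directions.
    coords = np.pad(coords, ((0, 0), (0, 2 - coords.shape[1])))
    total = float(np.sum(s**2))
    retained = float(np.sum(s[:2] ** 2)) / total if total > 0 else 1.0
    n = protos.shape[0]
    return coords[:n], (None if feats is None else coords[n:]), retained
```

A `retained` value well below 1 means the 2D view discards much of the variance, so apparent distances in the plot should be read with care.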

Parameters:
  • prototypes (torch.Tensor) – The learned prototype vectors of shape [num_prototypes, embedding_dim].

  • prototype_labels (List[str]) – A list of names for each prototype.

  • features (Optional[torch.Tensor]) – Optional feature vectors to plot, e.g., from a grounded feature bank. Shape [num_features, embedding_dim].

  • feature_labels (Optional[List[str]]) – Optional names for each feature vector.

  • reduction_method (str) – ‘pca’ or ‘tsne’ for dimensionality reduction. Defaults to ‘pca’.

  • title (str) – The title of the plot. Defaults to “Learned Prototype Space”.

  • ax (Optional[plt.Axes]) – A matplotlib axes object to plot on. If None, a new figure and axes are created.
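The `ax` parameter follows the usual matplotlib convention: draw on the given Axes, or create a new figure only when none is passed, and return the Axes either way. A minimal sketch of that pattern for labeled 2D points; `scatter_with_labels` is a hypothetical helper, not part of verskyt.

```python
import matplotlib.pyplot as plt
import numpy as np


def scatter_with_labels(points, labels, ax=None, marker="o", title=None):
    """Scatter 2D points and annotate each with its label (hypothetical helper)."""
    points = np.asarray(points, dtype=float)
    if len(labels) != len(points):
        raise ValueError("need exactly one label per point")
    if ax is None:
        # Create a new figure only when the caller did not supply axes.
        _, ax = plt.subplots(figsize=(6, 5))
    ax.scatter(points[:, 0], points[:, 1], marker=marker)
    for (x, y), label in zip(points, labels):
        # Offset each label slightly so it does not sit on top of its marker.
        ax.annotate(label, (x, y), textcoords="offset points", xytext=(4, 4))
    if title is not None:
        ax.set_title(title)
    return ax
```

Since `plot_prototype_space` accepts `ax` and returns it, passing `axes[0]` and `axes[1]` from `plt.subplots(1, 2)` to two calls, one with `reduction_method="pca"` and one with `"tsne"`, should place both views side by side.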

Returns:

The matplotlib axes object with the plot.

Return type:

plt.Axes

Raises:

ValueError – If reduction_method is not ‘pca’ or ‘tsne’.

Note

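This visualization is particularly useful for understanding the conceptual structure learned by TNNs, as described in Doumbouya et al. (2025). PCA is a linear projection that tends to preserve global structure (large-scale distances and directions), while t-SNE is better at revealing local clusters but does not preserve global distances, so the spacing between well-separated clusters in a t-SNE plot should not be over-interpreted.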
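A minimal sketch of dispatching on `reduction_method` while honoring the documented `ValueError` contract, assuming a NumPy SVD for PCA and scikit-learn's `TSNE` for t-SNE (verskyt may implement either differently; `reduce_to_2d` is a hypothetical name). One practical detail: scikit-learn's `TSNE` rejects a perplexity that is not smaller than the number of samples, and its default perplexity is 30, so embedding only a handful of prototypes needs a smaller value.

```python
import numpy as np


def reduce_to_2d(x, method="pca", random_state=0):
    """Reduce the rows of `x` to 2D with PCA or t-SNE (hypothetical helper).

    Raises ValueError for any other method, mirroring the documented contract.
    """
    x = np.asarray(x, dtype=float)
    if method == "pca":
        centered = x - x.mean(axis=0, keepdims=True)
        _, _, vt = np.linalg.svd(centered, full_matrices=False)
        coords = centered @ vt[:2].T
        # Pad with zeros if there are fewer than two principal directions.
        return np.pad(coords, ((0, 0), (0, 2 - coords.shape[1])))
    if method == "tsne":
        from sklearn.manifold import TSNE  # imported lazily: only needed for t-SNE

        # Clamp perplexity below n_samples so small sets (e.g. 3 prototypes) work.
        perplexity = min(30.0, max(1.0, x.shape[0] - 1.0))
        tsne = TSNE(
            n_components=2,
            perplexity=perplexity,
            init="pca",
            random_state=random_state,  # t-SNE is stochastic; fix the seed
        )
        return tsne.fit_transform(x)
    raise ValueError(f"reduction_method must be 'pca' or 'tsne', got {method!r}")
```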

Example

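>>> import torch
>>> import matplotlib.pyplot as plt
>>> from verskyt.visualizations import plot_prototype_space
>>> prototypes = torch.randn(3, 128)
>>> labels = ["Low-Risk", "Medium-Risk", "High-Risk"]
>>> ax = plot_prototype_space(prototypes, labels)
>>> plt.show()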

visualize_prototypes_as_data(encoder: Module, prototypes: Tensor, prototype_labels: List[str], dataloader: DataLoader, top_k: int = 5, device: str | device | None = None) → Figure

Visualizes prototypes by showing the top_k most similar data samples.

This function offers the most intuitive form of interpretation: it shows what a prototype “looks like” by retrieving the real data samples whose embeddings are most similar to it. This approach is more general than data-domain specification of prototypes: it works on an already-trained model and does not require retraining.

Parameters:
  • encoder (torch.nn.Module) – The part of the model that produces the embeddings (i.e., the layers before the TverskyProjectionLayer).

  • prototypes (torch.Tensor) – The learned prototype vectors of shape [num_prototypes, embedding_dim].

  • prototype_labels (List[str]) – A list of names for each prototype.

  • dataloader (torch.utils.data.DataLoader) – A dataloader for the dataset (preferably the training set) with shuffle=False.

  • top_k (int) – The number of data samples to show for each prototype. Defaults to 5.

  • device (Optional[Union[str, torch.device]]) – The device to run computations on. If None, uses the same device as prototypes.
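Before any similarity can be computed, every sample must pass through the encoder. A minimal PyTorch sketch of that step, assuming batches are `(inputs, targets)` pairs (verskyt's actual batch handling is not documented here; `collect_embeddings` is a hypothetical name). It also illustrates one likely reason for `shuffle=False`: with sequential order, row i of the stacked embeddings lines up with dataset index i, so a top-k index maps back to a concrete sample.

```python
import torch
from torch.utils.data import DataLoader


@torch.no_grad()
def collect_embeddings(encoder: torch.nn.Module, dataloader: DataLoader, device="cpu"):
    """Encode every batch and keep the raw inputs in dataloader order.

    Hypothetical helper. Assumes each batch is an (inputs, targets) pair.
    Note: switches the encoder to eval mode (disables dropout, freezes BN stats).
    """
    encoder = encoder.to(device).eval()
    embeddings, inputs = [], []
    for batch_inputs, _ in dataloader:
        # Flatten any trailing dimensions to get one vector per sample.
        emb = encoder(batch_inputs.to(device)).flatten(start_dim=1)
        embeddings.append(emb.cpu())
        inputs.append(batch_inputs.cpu())
    return torch.cat(embeddings), torch.cat(inputs)
```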

Returns:

The matplotlib figure containing the visualization.

Return type:

plt.Figure

Note

This function uses cosine similarity to find the data samples most similar to each prototype. The visualization assumes image data in channel-first format (C, H, W) and converts it to channel-last (H, W, C) for display.
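The note names two concrete steps: ranking samples by cosine similarity and converting (C, H, W) images to (H, W, C). A NumPy sketch of both, operating on embeddings that have already been computed; the helper names are hypothetical and this is not verskyt's code.

```python
import numpy as np


def top_k_by_cosine(embeddings, prototypes, k):
    """Indices of the k samples most cosine-similar to each prototype.

    embeddings: [num_samples, dim]; prototypes: [num_prototypes, dim].
    Returns [num_prototypes, k], each row ordered from most to least similar.
    """
    emb = np.asarray(embeddings, dtype=float)
    protos = np.asarray(prototypes, dtype=float)
    # L2-normalise rows so a plain dot product equals cosine similarity.
    emb = emb / np.clip(np.linalg.norm(emb, axis=1, keepdims=True), 1e-12, None)
    protos = protos / np.clip(np.linalg.norm(protos, axis=1, keepdims=True), 1e-12, None)
    sims = protos @ emb.T  # [num_prototypes, num_samples]
    k = min(k, emb.shape[0])
    return np.argsort(-sims, axis=1, kind="stable")[:, :k]


def chw_to_hwc(image):
    """Convert a channel-first image (C, H, W) to channel-last for display."""
    hwc = np.transpose(np.asarray(image), (1, 2, 0))
    # matplotlib's imshow accepts (H, W), (H, W, 3) or (H, W, 4), not (H, W, 1).
    return hwc[:, :, 0] if hwc.shape[2] == 1 else hwc
```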

Example

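>>> import matplotlib.pyplot as plt
>>> from verskyt.visualizations import visualize_prototypes_as_data
>>> # Assumes `model` is a trained TNN exposing an `encoder` submodule and a
>>> # `tnn_layer` holding the prototypes, and `train_loader` uses shuffle=False
>>> fig = visualize_prototypes_as_data(
...     encoder=model.encoder,
...     prototypes=model.tnn_layer.prototypes,
...     prototype_labels=["Class 0", "Class 1"],
...     dataloader=train_loader,
...     top_k=3
... )
>>> plt.show()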

Usage Examples

Basic Prototype Space Visualization

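import matplotlib.pyplot as plt
from verskyt.visualizations import plot_prototype_space

# Assumes `model` is a trained TNN whose `tnn_layer` holds the learned prototypes.
# Detaching from the autograd graph before plotting is harmless and avoids NumPy errors.
prototypes = model.tnn_layer.prototypes.detach().cpu()
labels = ["Low-Risk", "Medium-Risk", "High-Risk"]

# Visualize the learned prototype space
ax = plot_prototype_space(prototypes, labels)
plt.show()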

Data-Based Prototype Interpretation

import matplotlib.pyplot as plt
from verskyt.visualizations import visualize_prototypes_as_data

# Assumes `model` exposes `encoder` and `tnn_layer`, and that `train_loader`
# was created with shuffle=False.
# Show which data samples are most similar to each prototype
fig = visualize_prototypes_as_data(
    encoder=model.encoder,
    prototypes=model.tnn_layer.prototypes,
    prototype_labels=["Class 0", "Class 1"],
    dataloader=train_loader,
    top_k=5
)
plt.show()

Requirements

The visualization module requires additional dependencies that can be installed with:

pip install "verskyt[visualization]"

Dependencies include:

  • matplotlib>=3.5.0

  • seaborn>=0.12.0

  • scikit-learn>=1.1.0
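To check that the optional dependencies listed above are importable before calling the plotting functions, a small standard-library sketch can look up each module. It checks presence only, not the minimum versions, and the helper name is hypothetical. Note that scikit-learn is imported as `sklearn`.

```python
import importlib.util

# Import names of the optional visualization dependencies (scikit-learn -> sklearn).
OPTIONAL_MODULES = ("matplotlib", "seaborn", "sklearn")


def missing_visualization_deps(modules=OPTIONAL_MODULES):
    """Return the modules that cannot be found in the current environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_visualization_deps()
    if missing:
        print("Missing:", ", ".join(missing), '- run: pip install "verskyt[visualization]"')
    else:
        print("All visualization dependencies are installed.")
```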