verskyt.visualizations
Visualization tools for Tversky Neural Networks.
This module provides functions for visualizing and interpreting the prototypes and features learned by TNNs. The functions are designed to make the abstract notions of “prototypes” and “features” concrete and inspectable for research analysis.
Module: plotting
- plot_prototype_space(prototypes: Tensor, prototype_labels: List[str], features: Tensor | None = None, feature_labels: List[str] | None = None, reduction_method: str = 'pca', title: str = 'Learned Prototype Space', ax: Axes | None = None) → Axes
Visualizes high-dimensional prototypes and features in a 2D space.
This function projects the prototype vectors (and, optionally, feature vectors) into 2D with PCA or t-SNE so that they can be plotted. The resulting map helps researchers inspect the conceptual relationships the model has learned.
- Parameters:
prototypes (torch.Tensor) – The learned prototype vectors of shape [num_prototypes, embedding_dim].
prototype_labels (List[str]) – A list of names for each prototype.
features (Optional[torch.Tensor]) – Optional feature vectors to plot, e.g., from a grounded feature bank. Shape [num_features, embedding_dim].
feature_labels (Optional[List[str]]) – Optional names for each feature vector.
reduction_method (str) – ‘pca’ or ‘tsne’ for dimensionality reduction. Defaults to ‘pca’.
title (str) – The title of the plot. Defaults to “Learned Prototype Space”.
ax (Optional[plt.Axes]) – A matplotlib axes object to plot on. If None, a new figure and axes are created.
- Returns:
The matplotlib axes object with the plot.
- Return type:
plt.Axes
- Raises:
ValueError – If reduction_method is not ‘pca’ or ‘tsne’.
Note
This visualization is particularly useful for understanding the conceptual structure learned by TNNs, as described in Doumbouya et al. (2025). PCA better preserves global structure. t-SNE emphasizes local neighborhoods and is better suited to revealing clusters.
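To give some intuition for the projection step, here is a minimal NumPy sketch. It is not verskyt's implementation. It assumes that prototypes and features are stacked and reduced together so that they share one coordinate system. PCA is implemented directly via an SVD of the centered data, and the 'tsne' branch delegates to scikit-learn's `TSNE`. The helper name `project_to_2d`, the `seed` argument, and the perplexity cap are illustrative choices.

```python
import numpy as np


def project_to_2d(prototypes, features=None, method="pca", seed=0):
    """Project prototypes (and optional features) into one shared 2D space.

    Inputs are array-likes of shape [n, embedding_dim]; convert torch tensors
    first with `t.detach().cpu().numpy()`.
    Returns (prototype_coords [P, 2], feature_coords [F, 2] or None).
    """
    if method not in ("pca", "tsne"):
        raise ValueError(f"method must be 'pca' or 'tsne', got {method!r}")

    # Stack everything so prototypes and features share the same projection.
    blocks = [np.asarray(prototypes, dtype=float)]
    if features is not None:
        blocks.append(np.asarray(features, dtype=float))
    points = np.vstack(blocks)

    if method == "pca":
        # PCA via SVD of the mean-centered data: rows of vt are the principal axes,
        # ordered by decreasing singular value.
        centered = points - points.mean(axis=0)
        _, _, vt = np.linalg.svd(centered, full_matrices=False)
        coords = centered @ vt[:2].T
    else:
        # t-SNE needs scikit-learn; its perplexity must be smaller than the number of points.
        from sklearn.manifold import TSNE
        perplexity = min(30.0, max(1.0, len(points) - 1.0))
        coords = TSNE(n_components=2, perplexity=perplexity, init="pca",
                      random_state=seed).fit_transform(points)

    # Pad to two columns if fewer than two principal axes exist (e.g. a single point).
    if coords.shape[1] < 2:
        coords = np.pad(coords, ((0, 0), (0, 2 - coords.shape[1])))

    n_protos = len(blocks[0])
    feature_coords = coords[n_protos:] if features is not None else None
    return coords[:n_protos], feature_coords
```

The returned coordinates can then be drawn with `ax.scatter` and labeled with `ax.annotate`. Reducing prototypes and features in a single call matters because two separate projections would place the two groups in unrelated coordinate systems.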
Example
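>>> import torch
>>> import matplotlib.pyplot as plt
>>> from verskyt.visualizations import plot_prototype_space
>>> prototypes = torch.randn(3, 128)
>>> labels = ["Low-Risk", "Medium-Risk", "High-Risk"]
>>> ax = plot_prototype_space(prototypes, labels)
>>> plt.show()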
- visualize_prototypes_as_data(encoder: Module, prototypes: Tensor, prototype_labels: List[str], dataloader: DataLoader, top_k: int = 5, device: str | device | None = None) → Figure
Visualizes prototypes by showing the top_k most similar data samples.
This function provides the most intuitive form of interpretation: showing what a prototype “looks like” by finding the real data points that are most similar to it. This approach is more general than data-domain specification and doesn’t require retraining.
- Parameters:
encoder (torch.nn.Module) – The part of the model that produces the embeddings (i.e., the layers before the TverskyProjectionLayer).
prototypes (torch.Tensor) – The learned prototype vectors of shape [num_prototypes, embedding_dim].
prototype_labels (List[str]) – A list of names for each prototype.
dataloader (torch.utils.data.DataLoader) – A dataloader for the dataset (preferably the training set) with shuffle=False.
top_k (int) – The number of data samples to show for each prototype. Defaults to 5.
device (Optional[Union[str, torch.device]]) – The device to run computations on. If None, uses the same device as prototypes.
- Returns:
The matplotlib figure containing the visualization.
- Return type:
plt.Figure
Note
This function uses cosine similarity between each prototype and the encoder's embedding of each data sample to find the top_k most similar samples. The visualization assumes image data in channel-first format (C, H, W) and converts each sample to channel-last (H, W, C) for display.
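The retrieval described in the note can be sketched as follows. This is an illustration under stated assumptions, not verskyt's implementation. `top_k_similar_samples` is a hypothetical helper. Batches are assumed to be `(inputs, targets)` pairs or bare input tensors, and encoder outputs are flattened to vectors before comparison. The returned indices are positions in iteration order, so they map back to `dataloader.dataset[i]` only when the loader does not shuffle. That is presumably why shuffle=False is recommended.

```python
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def top_k_similar_samples(encoder: nn.Module, prototypes: torch.Tensor,
                          dataloader: DataLoader, top_k: int = 5,
                          device=None) -> torch.Tensor:
    """Return sample indices of shape [num_prototypes, k], ranked by cosine similarity."""
    device = torch.device(device) if device is not None else prototypes.device
    # Note: this moves the encoder to `device` and switches it to eval mode.
    encoder = encoder.to(device).eval()
    # Unit-normalize prototypes so that a dot product equals cosine similarity.
    protos = F.normalize(prototypes.detach().to(device), dim=-1)

    scores = []
    for batch in dataloader:
        # Assumes batches are (inputs, targets) pairs or bare input tensors.
        inputs = batch[0] if isinstance(batch, (list, tuple)) else batch
        emb = encoder(inputs.to(device)).flatten(start_dim=1)
        scores.append(F.normalize(emb, dim=-1) @ protos.T)  # [batch, num_prototypes]
    scores = torch.cat(scores, dim=0)  # [num_samples, num_prototypes]

    # Rank samples per prototype; row i holds the indices of prototype i's best matches.
    k = min(top_k, scores.shape[0])
    return scores.topk(k, dim=0).indices.T


if __name__ == "__main__":
    # Toy check: an identity "encoder" over one-hot samples.
    loader = DataLoader(TensorDataset(torch.eye(4), torch.zeros(4)), batch_size=2, shuffle=False)
    protos = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 2.0, 0.0]])
    print(top_k_similar_samples(nn.Identity(), protos, loader, top_k=1))  # indices [[0], [2]]
```

Normalizing both sides once and using a single matrix product per batch avoids computing pairwise cosine similarities one prototype at a time.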
Example
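>>> import matplotlib.pyplot as plt
>>> from verskyt.visualizations import visualize_prototypes_as_data
>>> # Assuming 'model' is a trained TNN with an encoder component
>>> # and 'train_loader' iterates the training set with shuffle=False
>>> fig = visualize_prototypes_as_data(
...     encoder=model.encoder,
...     prototypes=model.tnn_layer.prototypes,
...     prototype_labels=["Class 0", "Class 1"],
...     dataloader=train_loader,
...     top_k=3
... )
>>> plt.show()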
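The channel-last conversion mentioned in the note above comes down to a single permute. The sketch below prepares one (C, H, W) sample for `plt.imshow`. The helper `to_displayable` is hypothetical, and the min-max rescaling is an added assumption that makes mean/std-standardized inputs viewable. It is not necessarily what verskyt does.

```python
import torch


def to_displayable(image: torch.Tensor) -> torch.Tensor:
    """Turn one (C, H, W) sample into an imshow-ready (H, W, C) or (H, W) tensor in [0, 1]."""
    if image.dim() != 3:
        raise ValueError(f"expected a (C, H, W) tensor, got shape {tuple(image.shape)}")
    # Channel-first -> channel-last, on CPU, detached from any autograd graph.
    img = image.detach().cpu().float().permute(1, 2, 0)
    # Assumption: min-max rescale so standardized (mean/std-normalized) inputs display sensibly.
    lo, hi = img.min(), img.max()
    img = (img - lo) / (hi - lo) if hi > lo else torch.zeros_like(img)
    # matplotlib expects single-channel images as (H, W) rather than (H, W, 1).
    return img.squeeze(-1) if img.shape[-1] == 1 else img
```

For example, `plt.imshow(to_displayable(dataset[i][0]), cmap="gray")` displays sample `i`. matplotlib ignores `cmap` for 3-channel images.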
Usage Examples
Basic Prototype Space Visualization
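import torch
import matplotlib.pyplot as plt
from verskyt.visualizations import plot_prototype_space
# Assume `model` is a trained TNN whose learned prototypes live in model.tnn_layer.prototypes
prototypes = model.tnn_layer.prototypes
# Provide one label per prototype (three prototypes here)
labels = ["Low-Risk", "Medium-Risk", "High-Risk"]
# Visualize the learned prototype space
ax = plot_prototype_space(prototypes, labels)
plt.show()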
Data-Based Prototype Interpretation
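import matplotlib.pyplot as plt
from verskyt.visualizations import visualize_prototypes_as_data
# Assume `model` is a trained TNN and `train_loader` iterates the training set with shuffle=False
# Show which data samples are most similar to each prototype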
fig = visualize_prototypes_as_data(
encoder=model.encoder,
prototypes=model.tnn_layer.prototypes,
prototype_labels=["Class 0", "Class 1"],
dataloader=train_loader,
top_k=5
)
plt.show()
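`visualize_prototypes_as_data` recommends a dataloader with shuffle=False, but training loaders usually shuffle. One option is to rebuild a sequential loader over the same dataset. The sketch below does that. `ordered_loader` is a hypothetical helper, not part of verskyt. It keeps the dataset, batch size, worker count, and collate function, but it does not carry over a custom sampler.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def ordered_loader(loader: DataLoader, batch_size=None) -> DataLoader:
    """Return a non-shuffling DataLoader over the same dataset as `loader`."""
    return DataLoader(
        loader.dataset,
        # Reuse the original batch size when it is set; fall back to 1 otherwise.
        batch_size=batch_size or loader.batch_size or 1,
        shuffle=False,  # sequential order, so position i corresponds to dataset[i]
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
    )


if __name__ == "__main__":
    dataset = TensorDataset(torch.arange(10.0))
    shuffled = DataLoader(dataset, batch_size=4, shuffle=True)
    ordered = torch.cat([batch[0] for batch in ordered_loader(shuffled)])
    print(ordered.tolist())  # values come back in dataset order, 0.0 through 9.0
```

With this helper, the call above could pass `dataloader=ordered_loader(train_loader)`.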
Requirements
The visualization module requires additional dependencies that can be installed with:
pip install "verskyt[visualization]"
Dependencies include:
matplotlib>=3.5.0
seaborn>=0.12.0
scikit-learn>=1.1.0
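A small standard-library sketch can check which of these optional packages are importable before any plotting. It only tests whether each module can be found. It does not check the minimum versions listed above. The mapping from pip names to import names is written out by hand (scikit-learn imports as `sklearn`), and `missing_visualization_deps` is a hypothetical helper.

```python
import importlib.util

# pip distribution name -> top-level import name (scikit-learn imports as sklearn).
OPTIONAL_DEPS = {"matplotlib": "matplotlib", "seaborn": "seaborn", "scikit-learn": "sklearn"}


def missing_visualization_deps(deps=OPTIONAL_DEPS):
    """Return the pip names of optional dependencies whose modules cannot be found."""
    return [pip_name for pip_name, module in deps.items()
            if importlib.util.find_spec(module) is None]


if __name__ == "__main__":
    missing = missing_visualization_deps()
    if missing:
        print("Missing:", ", ".join(missing), '-> pip install "verskyt[visualization]"')
    else:
        print("All visualization dependencies are importable.")
```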