neural_network_3d

braintools.visualize.neural_network_3d(layer_sizes, weights=None, activations=None, layer_spacing=2.0, neuron_spacing=1.0, node_size=100, edge_alpha=0.3, ax=None, figsize=(12, 10), title=None, **kwargs)

Visualize neural network architecture in 3D.

Parameters:
  • layer_sizes (List[int]) – Number of neurons in each layer.

  • weights (List[ndarray] | None) – Weight matrices between layers.

  • activations (List[ndarray] | None) – Neuron activations for coloring.

  • layer_spacing (float) – Spacing between layers.

  • neuron_spacing (float) – Spacing between neurons in a layer.

  • node_size (float) – Size of neuron nodes.

  • edge_alpha (float) – Alpha transparency for connections.

  • ax (Axes3D | None) – 3D axes to plot on. If None, a new figure is created (see figsize).

  • figsize (Tuple[float, float]) – Figure size, used only when a new figure is created.

  • title (str | None) – Plot title.

  • **kwargs – Additional arguments passed to scatter.
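To make the two spacing parameters concrete, the sketch below computes node coordinates for one plausible layout: layers placed along the x axis at multiples of layer_spacing, and the neurons of each layer centred on a line along y at multiples of neuron_spacing. This layout and the helper name layout_positions are assumptions made for illustration. They are not taken from the braintools source, which may arrange nodes differently.

```python
from typing import List, Tuple

Point = Tuple[float, float, float]


def layout_positions(
    layer_sizes: List[int],
    layer_spacing: float = 2.0,
    neuron_spacing: float = 1.0,
) -> List[List[Point]]:
    """Return one list of (x, y, z) points per layer.

    Hypothetical layout, not the library's actual implementation:
    layers advance along x, neurons are centred around y = 0, z is fixed at 0.
    """
    positions = []
    for layer_index, n_neurons in enumerate(layer_sizes):
        x = layer_index * layer_spacing
        # Shift so that the neurons of a layer are centred on y = 0.
        offset = (n_neurons - 1) * neuron_spacing / 2.0
        layer = [(x, i * neuron_spacing - offset, 0.0) for i in range(n_neurons)]
        positions.append(layer)
    return positions


positions = layout_positions([3, 2], layer_spacing=2.0, neuron_spacing=1.0)
for layer in positions:
    print(layer)
```

Under this assumed layout, increasing layer_spacing moves the layers further apart, while increasing neuron_spacing spreads the nodes within each layer.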

Returns:

ax – The 3D axes object.

Return type:

Axes3D
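A minimal usage sketch follows, based only on the signature documented above. The helper names build_example_inputs and plot_example are hypothetical. The weight shape convention used here, (neurons in source layer, neurons in target layer), is an assumption, as is the per-layer 1D activation array. Check both against the library before relying on them.

```python
import numpy as np


def build_example_inputs(layer_sizes, seed=0):
    """Create random weights and activations matching layer_sizes."""
    rng = np.random.default_rng(seed)
    # Assumed convention: one (n_source, n_target) matrix per pair of adjacent layers.
    weights = [
        rng.normal(size=(n_in, n_out))
        for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
    ]
    # Assumed convention: one 1D array of activations per layer.
    activations = [rng.random(n) for n in layer_sizes]
    return weights, activations


def plot_example():
    """Draw a small 4-6-3 network and return the 3D axes."""
    # Imported here so the input helper above works without braintools installed.
    from braintools.visualize import neural_network_3d

    layer_sizes = [4, 6, 3]
    weights, activations = build_example_inputs(layer_sizes)
    ax = neural_network_3d(
        layer_sizes,
        weights=weights,
        activations=activations,
        layer_spacing=2.0,
        neuron_spacing=1.0,
        node_size=100,
        edge_alpha=0.3,
        figsize=(12, 10),
        title="Example 4-6-3 network",
    )
    return ax
```

Calling plot_example() and then matplotlib.pyplot.show() should display the figure. The returned Axes3D can be adjusted further, for example with ax.view_init(elev=20, azim=-60) to change the camera angle.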