neural_network_3d
- class braintools.visualize.neural_network_3d(layer_sizes, weights=None, activations=None, layer_spacing=2.0, neuron_spacing=1.0, node_size=100, edge_alpha=0.3, ax=None, figsize=(12, 10), title=None, **kwargs)
Visualize neural network architecture in 3D.
- Parameters:
weights (List[ndarray] | None) – Weight matrices between layers.
activations (List[ndarray] | None) – Neuron activations for coloring.
layer_spacing (float) – Spacing between layers.
neuron_spacing (float) – Spacing between neurons in a layer.
node_size (float) – Size of neuron nodes.
edge_alpha (float) – Alpha transparency for connections.
ax (Axes3D | None) – 3D axes to plot on.
figsize (Tuple[float, float]) – Figure size if creating new figure.
**kwargs – Additional arguments passed to scatter.
- Returns:
ax – The 3D axes object.
- Return type:
Axes3D
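A minimal usage sketch follows, written against the signature above only. Several details are assumptions rather than documented behavior: that `weights` holds one matrix per pair of consecutive layers with shape `(n_in, n_out)`, that `activations` holds one 1-D array per layer, and that `layer_sizes` is a list of integers. The helpers `make_demo_inputs` and `plot_demo` are hypothetical names introduced for illustration; they are not part of braintools.

```python
import numpy as np


def make_demo_inputs(layer_sizes, seed=0):
    """Hypothetical helper: random weights and activations for a layer layout."""
    rng = np.random.default_rng(seed)
    # Assumption: one matrix per consecutive layer pair, shaped (n_in, n_out).
    weights = [
        rng.normal(size=(n_in, n_out))
        for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
    ]
    # Assumption: one 1-D activation array per layer, values in [0, 1).
    activations = [rng.random(n) for n in layer_sizes]
    return weights, activations


def plot_demo(layer_sizes, weights, activations):
    """Hypothetical wrapper: call neural_network_3d with the documented keywords."""
    # Imported here so the input-building code above runs without braintools.
    from braintools.visualize import neural_network_3d

    return neural_network_3d(
        layer_sizes,
        weights=weights,
        activations=activations,
        layer_spacing=3.0,   # wider gap between layers than the default 2.0
        neuron_spacing=1.0,
        node_size=80,
        edge_alpha=0.2,      # fainter connections than the default 0.3
        title="Demo network",
    )


layer_sizes = [4, 6, 3]
weights, activations = make_demo_inputs(layer_sizes)
```

With braintools installed, `ax = plot_demo(layer_sizes, weights, activations)` should return the 3D axes, which can then be adjusted with the usual Matplotlib calls before showing or saving the figure. If the library expects the transposed weight orientation, swap `n_in` and `n_out` in the helper.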