scatter_matrix

braintools.visualize.scatter_matrix(data, labels=None, diagonal='hist', alpha=0.7, figsize=(12, 12), color='blue', ax=None, **kwargs)

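Create a scatter plot matrix for multivariate data: each off-diagonal panel plots one feature against another, and each diagonal panel shows the distribution of a single feature.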

Parameters:
  • data (ndarray) – Data matrix of shape (n_samples, n_features); each column is treated as one feature.

  • labels (List[str] | None) – Feature labels, one per column of data.

  • diagonal (str) – What to plot on the diagonal: ‘hist’ (histogram) or ‘kde’ (kernel density estimate). Defaults to ‘hist’.

  • alpha (float) – Alpha (opacity) of the plotted points, from 0 (fully transparent) to 1 (opaque).

  • figsize (Tuple[float, float]) – Figure size in inches as (width, height); only used if ax is None.

  • color (str) – Plot color.

  • ax (Axes | None) – Single axes on which to draw a simplified version (a 2x2 subset of the matrix). If None, a new figure containing the full scatter matrix is created.

  • **kwargs – Additional keyword arguments passed to the scatter plots.

Returns:

fig – The figure object containing all subplots.

Return type:

Figure