"""Simple plotting functions for 21cmFAST objects."""
from __future__ import annotations
import matplotlib.pyplot as plt
import numpy as np
from astropy import units as un
from astropy.cosmology import z_at_value
from matplotlib import colormaps, colors
from matplotlib.ticker import AutoLocator
from typing import Optional, Union
from . import outputs
from .outputs import Coeval, LightCone
# Custom colormap for 21-cm brightness-temperature slices. It is registered under
# the name "EoR" so it can be referenced by name (``cmap="EoR"``).
eor_colour = colors.LinearSegmentedColormap.from_list(
    "EoR",
    [
        (0, "white"),
        (0.21, "yellow"),
        (0.42, "orange"),
        (0.63, "red"),
        (0.86, "black"),
        (0.9, "blue"),
        (1, "cyan"),
    ],
)
# Registering an already-registered name raises ``ValueError`` in matplotlib, which
# would break re-importing this module (e.g. ``importlib.reload`` or notebook
# autoreload), so only register it once.
if eor_colour.name not in colormaps:
    colormaps.register(cmap=eor_colour)
def _imshow_slice(
    cube,
    slice_axis=-1,
    slice_index=0,
    fig=None,
    ax=None,
    fig_kw=None,
    cbar=True,
    cbar_horizontal=False,
    rotate=False,
    cmap="EoR",
    log: bool = False,
    **imshow_kw,
):
    """
    Plot a slice of some kind of cube.

    Parameters
    ----------
    cube : nd-array
        A 3D array of some quantity.
    slice_axis : int, optional
        The axis over which to take a slice, in order to plot.
    slice_index : int, optional
        The index of the slice along ``slice_axis``.
    fig : Figure object
        An optional matplotlib figure object on which to plot. If ``ax`` is not
        given, the figure's current axis is used.
    ax : Axis object
        The matplotlib axis object on which to plot (created by default).
    fig_kw : dict, optional
        Optional arguments passed to the figure construction (only used when
        neither ``fig`` nor ``ax`` is given).
    cbar : bool
        Whether to plot the colorbar.
    cbar_horizontal : bool
        Whether the colorbar should be horizontal underneath the plot.
    rotate : bool
        Whether to rotate the plot vertically. If False, the 2D slice is transposed
        before plotting.
    cmap : str or Colormap, optional
        The colormap to use. For ``"EoR"`` the colour range defaults to
        ``vmin=-150, vmax=30`` unless ``vmin``/``vmax`` are given explicitly.
    log : bool, optional
        Whether to use a logarithmic colour normalisation.
    imshow_kw :
        Optional keywords to pass to :func:`matplotlib.pyplot.imshow`. ``vmin`` and
        ``vmax`` are used to build the colour normalisation; an explicit ``norm``
        takes precedence over them (and over ``log``).

    Returns
    -------
    fig, ax :
        The figure and axis objects from matplotlib.

    Raises
    ------
    IndexError
        If ``slice_index`` is too large for ``slice_axis``.
    """
    # If no axis is passed, create a new one.
    # This allows the user to add this plot into an existing grid, or alter it afterwards.
    if fig_kw is None:
        fig_kw = {}
    if ax is None and fig is None:
        fig, ax = plt.subplots(1, 1, **fig_kw)
    elif ax is None:
        # Use the given figure's axis, not whatever pyplot considers "current".
        ax = fig.gca()
    elif fig is None:
        fig = ax.figure
    # The pyplot state-machine calls below (imshow/colorbar) act on the current axis.
    plt.sca(ax)

    if slice_index >= cube.shape[slice_axis]:
        raise IndexError(
            "slice_index is too large for that axis (slice_index=%s >= %s)"
            % (slice_index, cube.shape[slice_axis])
        )

    slc = np.take(cube, slice_index, axis=slice_axis)
    if not rotate:
        slc = slc.T

    if cmap == "EoR":
        # Default colour range for this colormap, but let callers override it.
        imshow_kw.setdefault("vmin", -150)
        imshow_kw.setdefault("vmax", 30)

    norm_kw = {k: imshow_kw.pop(k) for k in ("vmin", "vmax") if k in imshow_kw}
    # ``norm`` must be popped (not just read): it is passed explicitly below, so
    # leaving it in ``imshow_kw`` would raise a duplicate-keyword TypeError.
    norm = imshow_kw.pop("norm", None)
    if norm is None:
        norm = colors.LogNorm(**norm_kw) if log else colors.Normalize(**norm_kw)

    im = plt.imshow(slc, origin="lower", cmap=cmap, norm=norm, **imshow_kw)

    if cbar:
        cb = plt.colorbar(
            im,
            orientation="horizontal" if cbar_horizontal else "vertical",
            aspect=40,
        )
        cb.outline.set_edgecolor(None)

    return fig, ax
def coeval_sliceplot(
    struct: outputs._OutputStruct | Coeval,
    kind: str | None = None,
    cbar_label: str | None = None,
    **kwargs,
):
    """
    Show a slice of a given coeval box.

    Parameters
    ----------
    struct : :class:`~outputs._OutputStruct` or :class:`~wrapper.Coeval` instance
        The output of a function such as `ionize_box` (a class containing several
        quantities), or `run_coeval`.
    kind : str, optional
        The quantity within the structure to be shown. A full list of available options
        can be obtained by running ``Coeval.get_fields()``. Defaults to the first
        field of an OutputStruct, or ``"brightness_temp"`` for a Coeval.
    cbar_label : str, optional
        A label for the colorbar. Some values of `kind` will have automatically chosen
        labels, but these can be turned off by setting ``cbar_label=''``.

    Returns
    -------
    fig, ax :
        figure and axis objects from matplotlib

    Raises
    ------
    TypeError
        If `kind` is not given and `struct` is neither an OutputStruct nor a Coeval.
    AttributeError
        If `struct` does not have the quantity `kind`.
    ValueError
        If `slice_axis` is not one of -1, 0, 1 or 2.

    Other Parameters
    ----------------
    All other parameters are passed directly to :func:`_imshow_slice`. These include
    `slice_axis` and `slice_index`, which choose the actual slice to plot, optional
    `fig` and `ax` keywords which enable over-plotting previous figures, and arbitrary
    `imshow` keywords, which allow arbitrary styling of the plot.
    """
    if kind is None:
        if isinstance(struct, outputs._OutputStruct):
            kind = struct.fieldnames[0]
        elif isinstance(struct, Coeval):
            kind = "brightness_temp"
        else:
            raise TypeError(
                "kind must be given when struct is neither an OutputStruct nor a Coeval"
            )

    # Validate the slice axis before doing any plotting, and work out which box
    # axes end up horizontal/vertical in the image (with the default transpose).
    slice_axis = kwargs.get("slice_axis", -1)
    image_axes = {2: ("x", "y"), -1: ("x", "y"), 1: ("x", "z"), 0: ("y", "z")}
    if slice_axis not in image_axes:
        raise ValueError("slice_axis should be between -1 and 2")
    xax, yax = image_axes[slice_axis]
    if kwargs.get("rotate", False):
        # With rotate=True the slice is not transposed, so the axes swap roles.
        xax, yax = yax, xax

    try:
        cube = getattr(struct, kind)
    except AttributeError as err:
        raise AttributeError(
            f"The given OutputStruct does not have the quantity {kind}"
        ) from err

    if kind != "brightness_temp" and "cmap" not in kwargs:
        kwargs["cmap"] = "viridis"
    # Physical extent of the box on both image axes, unless the caller overrides it.
    kwargs.setdefault("extent", (0, struct.user_params.BOX_LEN) * 2)

    fig, ax = _imshow_slice(cube, **kwargs)

    # Now put on the decorations.
    ax.set_xlabel(f"{xax}-axis [Mpc]")
    ax.set_ylabel(f"{yax}-axis [Mpc]")

    cbar = ax.images[-1].colorbar
    if cbar is not None:
        if cbar_label is None:
            if kind == "brightness_temp":
                cbar_label = r"Brightness Temperature, $\delta T_B$ [mK]"
            elif kind == "xH_box":
                cbar_label = r"Neutral fraction"
        # set_label labels the colorbar's long axis, so this also works for
        # horizontal colorbars (``cbar_horizontal=True``).
        cbar.set_label(cbar_label)

    return fig, ax
def lightcone_sliceplot(
    lightcone: LightCone,
    kind: str = "brightness_temp",
    lightcone2: LightCone | None = None,
    vertical: bool = False,
    xlabel: str | None = None,
    ylabel: str | None = None,
    cbar_label: str | None = None,
    zticks: str = "redshift",
    fig: plt.Figure | None = None,
    ax: plt.Axes | None = None,
    **kwargs,
):
    """Create a 2D plot of a slice through a lightcone.

    Parameters
    ----------
    lightcone : :class:`~py21cmfast.wrapper.Lightcone`
        The lightcone object to plot.
    kind : str, optional
        The attribute of the lightcone to plot. Must be an array.
    lightcone2 : :class:`~py21cmfast.wrapper.Lightcone`, optional
        If provided, plot the _difference_ of the selected attribute between the two
        lightcones, using a colour range that is symmetric about zero.
    vertical : bool, optional
        Whether to plot the redshift in the vertical direction.
    xlabel, ylabel : str, optional
        Axis labels. By default spatial axes get "<axis>-axis [Mpc]" and the
        line-of-sight axis gets a label derived from `zticks`. Set to ``''`` to
        suppress a label.
    cbar_label : str, optional
        A label for the colorbar. Some quantities have automatically chosen labels, but
        these can be removed by setting `cbar_label=''`.
    zticks : str, optional
        Defines the co-ordinates of the ticks along the redshift axis.
        Can be "redshift" (default), "frequency", "distance" (which starts at zero
        for the lowest redshift) or the name of any function in an astropy cosmology
        that is purely a function of redshift.
    fig, ax : optional
        Existing matplotlib Figure/Axes to draw on. If neither is given, a new figure
        sized to the lightcone is created.
    kwargs :
        Passed through to :func:`_imshow_slice` (e.g. ``slice_axis``, ``slice_index``,
        ``cbar``, ``cmap``, ``log``, ``vmin``, ``vmax``) and from there to ``imshow()``.

    Returns
    -------
    fig :
        The matplotlib Figure object
    ax :
        The matplotlib Axis object onto which the plot was drawn.

    Raises
    ------
    ValueError
        If `slice_axis` is not between -1 and 2.
    """
    slice_axis = kwargs.pop("slice_axis", 0)
    if slice_axis <= -2 or slice_axis >= 3:
        raise ValueError(f"slice_axis should be between -1 and 2 (got {slice_axis})")

    # The line-of-sight direction (lightcone dimension 2) is visible in the image
    # unless we slice along it (slice_axis 2 / -1).
    z_axis = ("y" if vertical else "x") if slice_axis in (0, 1) else None

    # Dictionary mapping each image axis to the lightcone dimension it displays.
    axis_dct = {
        "x": 2 if z_axis == "x" else [1, 0, 0][slice_axis],
        "y": 2 if z_axis == "y" else [1, 0, 1][slice_axis],
    }

    if fig is None and ax is None:
        fig, ax = plt.subplots(
            1,
            1,
            figsize=(
                lightcone.shape[axis_dct["x"]] * 0.015 + 0.5,
                lightcone.shape[axis_dct["y"]] * 0.015
                + (2.5 if kwargs.get("cbar", True) else 0.05),
            ),
        )
    elif fig is None:
        # ``ax._gci()`` is None on an axis without images; ask the axis directly.
        fig = ax.figure
    elif ax is None:
        # ``fig.get_axes()`` returns a list; a single axis is needed to draw on.
        ax = fig.gca()

    # Get x,y labels if they're not the redshift axis.
    if xlabel is None:
        xlabel = (
            None if axis_dct["x"] == 2 else "{}-axis [Mpc]".format("xy"[axis_dct["x"]])
        )
    if ylabel is None:
        ylabel = (
            None if axis_dct["y"] == 2 else "{}-axis [Mpc]".format("xy"[axis_dct["y"]])
        )

    extent = (
        0,
        lightcone.lightcone_dimensions[axis_dct["x"]],
        0,
        lightcone.lightcone_dimensions[axis_dct["y"]],
    )

    if lightcone2 is None:
        data = getattr(lightcone, kind)
        # setdefault (rather than an explicit ``cmap=``) avoids a duplicate-keyword
        # error when the caller passes ``cmap`` in kwargs.
        kwargs.setdefault("cmap", "EoR" if kind == "brightness_temp" else "viridis")
    else:
        data = getattr(lightcone, kind) - getattr(lightcone2, kind)
        kwargs.setdefault("cmap", "bwr")
        # Symmetric colour range about zero covering the full spread of differences
        # (``abs(max)`` alone would clip data that is mostly negative).
        vlim = np.nanmax(np.abs(data))
        kwargs.setdefault("vmin", -vlim)
        kwargs.setdefault("vmax", vlim)

    fig, ax = _imshow_slice(
        data,
        extent=extent,
        slice_axis=slice_axis,
        rotate=not vertical,
        cbar_horizontal=not vertical,
        fig=fig,
        ax=ax,
        **kwargs,
    )

    zlabel = None
    if z_axis:
        zlabel = _set_zaxis_ticks(ax, lightcone, zticks, z_axis)

    if ylabel != "":
        ax.set_ylabel(ylabel or zlabel)
    if xlabel != "":
        ax.set_xlabel(xlabel or zlabel)

    # No colorbar exists when ``cbar=False`` was passed.
    cbar = ax.images[-1].colorbar
    if cbar is not None:
        if cbar_label is None:
            if kind == "brightness_temp":
                cbar_label = r"Brightness Temperature, $\delta T_B$ [mK]"
            elif kind in ("xH", "xH_box"):
                cbar_label = r"Neutral fraction"
        # The colorbar is vertical exactly when ``vertical`` is True; set_label
        # labels its long axis in either orientation.
        cbar.set_label(cbar_label)

    return fig, ax
def _set_zaxis_ticks(ax, lightcone, zticks, z_axis):
    """Place ticks along the line-of-sight axis of a lightcone slice plot.

    The line-of-sight image axis is in comoving distance [Mpc], measured from the
    first lightcone distance. For any `zticks` other than "distance", ticks are chosen
    at "nice" values of the requested coordinate and positioned at the matching
    distance.

    Parameters
    ----------
    ax : matplotlib Axes
        The axis holding the lightcone slice.
    lightcone : LightCone
        The lightcone being plotted (provides redshifts, distances and cosmology).
    zticks : str
        "redshift", "frequency", "distance", or the name of a method of the
        lightcone's astropy cosmology that takes redshift.
    z_axis : str
        "x" or "y": which axis of `ax` is the line-of-sight direction.

    Returns
    -------
    zlabel : str
        A label for the line-of-sight axis.

    Raises
    ------
    AttributeError
        If `zticks` is not an attribute of the lightcone's cosmology.
    """
    if zticks == "distance":
        # Tick positions are already distances in Mpc; nothing to relabel.
        return "Line-of-Sight Distance [Mpc]"

    cosmo = lightcone.cosmo_params.cosmo
    lc_z = lightcone.lightcone_redshifts

    if zticks == "redshift":
        coords = lc_z
    elif zticks == "frequency":
        # Observed frequency of the 21-cm line (rest frequency taken as 1420 MHz).
        coords = 1420 / (1 + lc_z) * un.MHz
    else:
        # Only the lookup is guarded, so errors raised inside the cosmology
        # function itself are not misreported as "not a cosmology function".
        try:
            zfunc = getattr(cosmo, zticks)
        except AttributeError:
            raise AttributeError(
                f"zticks '{zticks}' is not a cosmology function."
            ) from None
        coords = zfunc(lc_z)

    zlabel = " ".join(word.capitalize() for word in zticks.split("_"))
    units = getattr(coords, "unit", None)
    if units is not None:
        coords = coords.value
        unit_str = str(units)
        if unit_str:  # dimensionless units stringify to ""
            zlabel += f" [{unit_str}]"

    ticks = AutoLocator().tick_values(coords.min(), coords.max())
    # The locator may pad beyond the data range; drop ticks that fall outside it.
    if ticks.min() < coords.min() / 1.00001:
        ticks = ticks[1:]
    if ticks.max() > coords.max() * 1.00001:
        ticks = ticks[:-1]
    # Keep tick order consistent with the direction the coordinate runs.
    if coords[1] < coords[0]:
        ticks = ticks[::-1]

    # Convert tick values back to redshift, then to distance along the axis.
    if zticks == "redshift":
        z_ticks = ticks
    elif zticks == "frequency":
        z_ticks = 1420 / ticks - 1
    else:
        # A function returning a plain (unitless) array gives ``units is None``.
        z_ticks = [
            z_at_value(zfunc, tick * units if units is not None else tick)
            for tick in ticks
        ]

    d_ticks = cosmo.comoving_distance(z_ticks) - lightcone.lightcone_distances[0]
    getattr(ax, f"set_{z_axis}ticks")(d_ticks.to_value("Mpc"))
    getattr(ax, f"set_{z_axis}ticklabels")(ticks)
    return zlabel
def plot_global_history(
    lightcone: LightCone,
    kind: str | None = None,
    ylabel: str | None = None,
    ylog: bool = False,
    ax: plt.Axes | None = None,
):
    """
    Plot the global history of a given quantity from a lightcone.

    Parameters
    ----------
    lightcone : :class:`~LightCone` instance
        The lightcone containing the quantity to plot.
    kind : str, optional
        The quantity to plot. Either a key of the lightcone's `global_quantities`
        dict, or the name of a ``global_<kind>`` attribute (with or without the
        ``global_`` prefix). By default, the first entry in the dict is used.
    ylabel : str, optional
        A y-label for the plot. If None, will use ``kind``; if ``''``, no label is set.
    ylog : bool, optional
        Whether to use a logarithmic y-axis.
    ax : Axes, optional
        The matplotlib Axes object on which to plot. Otherwise, created.

    Returns
    -------
    fig, ax :
        The matplotlib Figure and Axes objects.

    Raises
    ------
    ValueError
        If `kind` is not given and the lightcone has no global quantities, or if
        `kind` cannot be found on the lightcone.
    """
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(7, 4))
    else:
        # ``ax._gci()`` is None for an axis without images, so take the figure directly.
        fig = ax.figure

    if kind is None:
        if not lightcone.global_quantities:
            raise ValueError(
                "The lightcone has no global quantities; please specify `kind`."
            )
        kind = next(iter(lightcone.global_quantities))

    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if kind in lightcone.global_quantities:
        value = lightcone.global_quantities[kind]
    elif kind.startswith("global_") and hasattr(lightcone, kind):
        value = getattr(lightcone, kind)
    elif hasattr(lightcone, "global_" + kind):
        value = getattr(lightcone, "global_" + kind)
    else:
        raise ValueError(f"The lightcone has no global quantity '{kind}'.")

    ax.plot(lightcone.node_redshifts, value)
    ax.set_xlabel("Redshift")

    if ylabel is None:
        ylabel = kind
    if ylabel:
        ax.set_ylabel(ylabel)

    if ylog:
        ax.set_yscale("log")

    return fig, ax