# Licensed under a 3-clause BSD style license - see LICENSE.rst
from contextlib import ContextDecorator
import numpy as np
from astropy import units as u
from astropy.utils.compat.optional_deps import HAS_MATPLOTLIB
__all__ = ["MplQuantityConverter", "quantity_support"]
__doctest_skip__ = ["quantity_support"]
_default_format = "latex_inline"
if HAS_MATPLOTLIB:
from matplotlib import ticker, units
class MplQuantityConverter(units.ConversionInterface, ContextDecorator):
"""Matplotlib converter for ``astropy.units.Quantity``.
Registers itself to matplotlib as the converter for
``astropy.units.Quantity`` when initialized. If used as a context manager,
it will restore the original converter upon exit. Also see
:meth:`quantity_support` for a convenient way to use this converter
with an optional format for the ``axisinfo``.
"""
def __init__(self):
# Keep track of original converter in case the context manager is
# used in a nested way.
self._original_converter = {u.Quantity: units.registry.get(u.Quantity)}
units.registry[u.Quantity] = self
[docs]
@staticmethod
def axisinfo(unit, axis, format=None):
"""Return a :class:`matplotlib.units.AxisInfo` for *unit* and *axis*.
Parameters
----------
unit : `~astropy.units.UnitBase`
The unit to format the axis for.
axis : `matplotlib.axis.Axis`
The matplotlib axis being formatted.
format : `astropy.units.format.Base` subclass or str or None, optional
The name of a format or a formatter class used to render the
axis label. If `None`, the module-level default
(``"latex_inline"``) is used.
"""
if format is None:
format = _default_format
if unit == u.radian:
def rad_fn(x, pos=None):
n = int((x / np.pi) * 2.0 + 0.25)
if n < 3:
return ("0", "π/2", "π")[n]
elif n % 2 == 0:
return f"{n // 2}π"
else:
return f"{n}π/2"
return units.AxisInfo(
majloc=ticker.MultipleLocator(base=np.pi / 2),
majfmt=ticker.FuncFormatter(rad_fn),
label=unit.to_string(),
)
elif unit == u.degree:
return units.AxisInfo(
majloc=ticker.AutoLocator(),
majfmt=ticker.FormatStrFormatter("%g°"),
label=unit.to_string(),
)
elif unit is not None:
return units.AxisInfo(label=unit.to_string(format))
return None
[docs]
@staticmethod
def convert(val, unit, axis):
if isinstance(val, u.Quantity):
return val.to_value(unit)
elif (
all(
hasattr(val, attr)
for attr in ["__getitem__", "__iter__", "__len__"]
)
and len(val) > 0
and isinstance(val[0], u.Quantity)
):
return np.array([v.to_value(unit) for v in val])
else:
return val
[docs]
@staticmethod
def default_units(x, axis):
if hasattr(x, "unit"):
return x.unit
elif (
all(hasattr(x, attr) for attr in ["__getitem__", "__iter__", "__len__"])
and len(x) > 0
and hasattr(x[0], "unit")
):
return x[0].unit
return None
def __enter__(self):
return self
def __exit__(self, type, value, tb):
if self._original_converter[u.Quantity] is None:
del units.registry[u.Quantity]
else:
units.registry[u.Quantity] = self._original_converter[u.Quantity]
else:
# Create mock-up class to avoid import errors when matplotlib is not available.
[docs]
class MplQuantityConverter:
def __init__(self, *args, **kwargs):
raise ImportError("matplotlib is required in order to use this class.")
[docs]
def quantity_support(format=None):
"""
Enable support for plotting :class:`astropy.units.Quantity` instances in
matplotlib.
May be (optionally) used with a ``with`` statement.
>>> import matplotlib.pyplot as plt
>>> from astropy import units as u
>>> from astropy import visualization
>>> with visualization.quantity_support():
... fig, ax = plt.subplots()
... ax.plot([1, 2, 3] * u.m)
... # doctest: +ELLIPSIS
... ax.plot([101, 125, 150] * u.cm)
... # doctest: +ELLIPSIS
... ax.yaxis.set_units(u.km)
... plt.draw()
The default axis unit is inferred from the first plot using a Quantity.
To override it, you can explicitly set the axis unit using
:meth:`matplotlib.axis.Axis.set_units`, for example,
``ax.yaxis.set_units(u.km)``.
Parameters
----------
format : :class:`astropy.units.format.Base` subclass or str or None, optional
The name of a format or a formatter class. If not
provided, defaults to ``latex_inline``.
"""
class MplQuantityConverterFormatted(MplQuantityConverter):
# Override to pass on `format`
@staticmethod
def axisinfo(unit, axis):
return MplQuantityConverter.axisinfo(unit, axis, format)
return MplQuantityConverterFormatted()