# Source code for MD_plotting_toolkit.plotting_utils

####################################################################
#                                                                  #
#    MD_plotting_toolkit,                                          #
#    a python package to visualize the results obtained from MD    #
#                                                                  #
#    Written by Wei-Tse Hsu <wehs7661@colorado.edu>                #
#    Copyright (c) 2021 University of Colorado Boulder             #
#                                                                  #
####################################################################
"""
The `plotting_utils` module provides various utilities for plotting.
"""
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rc


def default_settings(font='Arial'):
    """
    Apply the default matplotlib settings used by this package.

    The settings are applied in three steps: a base sans-serif font
    configuration, the mathtext default, and finally the requested font
    family. The last step sets the font family again, so ``font`` takes
    precedence over the "sans-serif" family configured in the first step.

    Parameters
    ----------
    font : str
        The font family passed to ``plt.rc("font", family=...)``.
    """
    base_font_opts = {
        "family": "sans-serif",
        "sans-serif": ["DejaVu Sans"],
        "size": 10,
    }
    rc("font", **base_font_opts)

    # Use the "regular" default for text rendered through mathtext.
    mathtext_opts = {"default": "regular"}
    rc("mathtext", **mathtext_opts)

    # Applied last so the caller's choice of font family wins.
    plt.rc("font", family=font)
def identify_var_units(label):
    """
    Parse an x- or y-axis label into a variable name and a unit string.

    Parameters
    ----------
    label : str or None
        The label of the x- or y-axis, e.g. ``"Time (ns)"``.

    Returns
    -------
    var : str or None
        The lower-cased variable name, i.e. the text before the last "("
        with one trailing space removed. It is ``None`` if ``label`` is
        ``None``, and "x" / "y" for the labels "X-axis" / "Y-axis".
    unit : str
        The unit of the variable. It is an empty string if the label has no
        parenthesized part. For a plain unit such as "(ns)", the unit is
        returned with a leading space (" ns"). If the parenthesized part
        contains "$", the unit is the text between the first two "$" signs
        with its last space-separated token dropped.
    """
    if label is None:
        var, unit = None, ""
    elif "(" not in label:
        var, unit = label.lower(), ""
    else:
        pieces = label.split("(")
        if "$" in pieces[1]:
            # Text between the first pair of "$", minus its last
            # space-separated token (e.g. "$ kJ/mol $" -> " kJ/mol").
            unit = " ".join(label.split("$")[1].split(" ")[:-1])
        else:
            unit = " " + pieces[1].split(")")[0]

        var = " ".join(pieces[:-1]).lower()
        # Drop the single space that separates the name from "(".
        # `endswith` (rather than `var[-1]`) keeps labels that start with
        # "(" -- whose variable part is empty -- from raising IndexError.
        if var.endswith(" "):
            var = var[:-1]

    # Placeholder axis labels map to the bare axis names.
    if label == "X-axis":
        var = "x"
    elif label == "Y-axis":
        var = "y"

    return var, unit
def get_fig_dimension(n_subplots):
    """
    Gets the number of plots in each row/column given the total number of
    subplots. The dimensions will be as close to a square as possible.

    Parameters
    ----------
    n_subplots : int
        The number of subplots. Must be at least 1.

    Returns
    -------
    n_rows : int
        The number of rows in the figure.
    n_cols : int
        The number of columns in the figure.

    Raises
    ------
    ValueError
        If ``n_subplots`` is smaller than 1.
    """
    if n_subplots < 1:
        raise ValueError("n_subplots must be at least 1.")

    # The smallest column count whose square can hold all subplots: the
    # square root itself for a perfect square, otherwise the next integer.
    n_cols = int(np.ceil(np.sqrt(n_subplots)))

    # Just enough rows to hold every subplot with that many columns.
    n_rows = int(np.ceil(n_subplots / n_cols))

    # Return order follows the documented contract: rows first, then columns.
    return n_rows, n_cols
def get_bars_locations(n_bars, n_groups, width):
    """
    Computes the center locations of all the bars in a grouped bar chart.

    The left edge of the first group sits at 0. Bars within a group are
    placed side by side, and consecutive groups are separated by a gap
    proportional to the width of one group.

    Parameters
    ----------
    n_bars : int
        The number of bars in each group.
    n_groups : int
        The total number of groups of bars.
    width : float
        The width of one bar.

    Returns
    -------
    bar_locs : np.array
        An array built from one list of bar locations per group. For
        example, if n_groups = 5, this is built from 5 lists, each holding
        n_bars locations.
    """
    group_width = n_bars * width

    # Sparse charts get a gap of 2/3 of a group's width; charts with 40 or
    # more bars in total get a tighter gap of 1/5.
    is_sparse = n_bars * n_groups < 40
    spacing = group_width * 2 / 3 if is_sparse else group_width * 1 / 5

    # Distance between the left edges of two consecutive groups.
    pitch = group_width + spacing

    # Center of the k-th bar (1-based) measured from its group's left edge.
    centers_in_group = [0.5 * (2 * k - 1) * width for k in range(1, n_bars + 1)]

    return np.array(
        [
            [g * pitch + center for center in centers_in_group]
            for g in range(n_groups)
        ]
    )