# Source code for chembee.plotting.calibration

import matplotlib.pyplot as plt
import matplotlib
import logging

from chembee.utils.file_utils import get_grid_positions, prepare_file_name_saving

# Configure logging for the plotting helpers. Records at INFO and above are
# written to "chembee.plotting.log" (a relative path, so it resolves against
# the current working directory), each line formatted as
# "LEVEL:MM/DD/YYYY HH:MM:SS AM/PM message".
# NOTE(review): this runs at import time and targets the *root* logger;
# logging.basicConfig does nothing if the root logger already has handlers,
# so whichever module configures logging first decides the destination.
logging.basicConfig(
    format="%(levelname)s:%(asctime)s %(message)s",
    datefmt="%m/%d/%Y %I:%M:%S %p",
    level=logging.INFO,
    filename="chembee.plotting.log",
)


# Raise matplotlib's global base font size to 32 pt. rcParams is process-wide,
# so this affects text drawn by any module after this import, not only the
# plots produced here.
matplotlib.rcParams.update({"font.size": 32})
# Module-level 15x15 inch figure, created as a side effect of importing this
# module.
# NOTE(review): plot_calibration takes its own `fig` argument, so nothing
# visible in this file uses this object directly — confirm whether callers
# import it before changing or removing it.
fig = plt.figure(figsize=(15, 15))


def plot_calibration(
    fig: plt.Figure,
    clf_list: list,
    calibration_displays: dict,
    ax_calibration_curve,
    grid_spec,
    grid: tuple,
    colors,
    file_name: str = "calibration",
    prefix: str = "calibration/",
):
    """Add one predicted-probability histogram per classifier and save the figure.

    The calibration-curve axes only get a grid and a title here; the curves
    themselves are presumably drawn by the caller beforehand (TODO confirm).

    Args:
        fig: Figure that receives the histogram subplots and is written to disk.
        clf_list: Classifiers to plot; each must expose a ``name`` attribute.
        calibration_displays: Mapping from classifier name to an object with a
            ``y_prob`` attribute holding the predicted probabilities
            (presumably scikit-learn ``CalibrationDisplay`` objects -- verify
            against the caller).
        ax_calibration_curve: Axes holding the calibration curves.
        grid_spec: Object indexable as ``grid_spec[row, col]`` whose items are
            accepted by ``fig.add_subplot`` (presumably a matplotlib
            ``GridSpec``).
        grid: ``(rows, cols)`` passed to ``get_grid_positions``.
        colors: Callable mapping a classifier index to a colour, e.g. a
            colormap.
        file_name: Base name of the output file, without extension.
        prefix: Location prefix handed to ``prepare_file_name_saving``.

    Returns:
        None. The figure is saved as a ``.png`` at the path returned by
        ``prepare_file_name_saving`` and is cleared afterwards.

    Raises:
        IndexError: If ``get_grid_positions`` yields fewer positions than
            there are classifiers.
        KeyError: If a classifier name is missing from
            ``calibration_displays``.
    """
    ax_calibration_curve.grid()
    ax_calibration_curve.set_title("Calibration plots")

    # Subplot slots come from the shared helper rather than being computed
    # inline, so the layout logic lives in one place.
    grid_positions = get_grid_positions(rows=grid[0], cols=grid[1])
    for i, clf in enumerate(clf_list):
        name = clf.name
        row, col = grid_positions[i]
        ax = fig.add_subplot(grid_spec[row, col])
        ax.hist(
            calibration_displays[name].y_prob,
            range=(0, 1),
            bins=100,
            label=name,
            color=colors(i),
        )
        # NOTE(review): only the histogram at index ``rows - 1`` gets an
        # x-axis label. This looks like it is meant to label the bottom of
        # the grid, but that depends on the ordering returned by
        # get_grid_positions -- confirm before changing.
        if i == grid[0] - 1:
            ax.set(xlabel="Mean predicted probability", ylabel="Count")
        else:
            ax.set(title="", xlabel="", ylabel="Count")

    file_name = prepare_file_name_saving(prefix, file_name, ending=".png")
    fig.tight_layout()
    # Save and clear the figure that was passed in. The pyplot-level calls act
    # on pyplot's *current* figure, which is not guaranteed to be ``fig``.
    fig.savefig(file_name)
    logging.info("Saved calibration plot %s", file_name)
    # Clearing the figure removes all of its axes, so a separate per-axes
    # clear beforehand is unnecessary.
    fig.clf()