import sys
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from chembee.utils.file_utils import prepare_file_name_saving
import logging
# Log INFO-level records with timestamps to "chembee.plotting.log" (relative to
# the current working directory).
# NOTE(review): logging.basicConfig runs at import time and (if the root logger
# has no handlers yet) configures the root logger for the whole process;
# consider moving this to the application entry point.
logging.basicConfig(
    format="%(levelname)s:%(asctime)s %(message)s",
    datefmt="%m/%d/%Y %I:%M:%S %p",
    level=logging.INFO,
    filename="chembee.plotting.log",
)
# Default font size; most plotting functions below set their own size again.
matplotlib.rcParams.update({"font.size": 32})
# NOTE(review): this figure is created at import time and no function in this
# module refers to the module-level name ``fig``; it looks like a leftover that
# keeps one open figure alive.
fig = plt.figure(figsize=(15, 15))
# TODO: Refactor this module: split it up, or move its functions into other
# modules with more suitable names.
def plot_collection_stratified(metrics_json: dict, file_name: str, prefix: str) -> None:
    """
    Plot the results of a repeated (stratified) evaluation of several algorithms.

    Scalar metrics are drawn as bar charts (mean with standard-deviation error
    bars) and the ROC data as one averaged ROC curve per algorithm. Matrix
    metrics are not plotted yet.

    :param metrics_json:dict: Evaluation results containing at least the keys
        ``"scalar"`` and ``"array"``; each maps algorithm name -> metric name -> values.
    :param file_name:str: Base file name; metric / algorithm names are prepended to it.
    :param prefix:str: Output directory / prefix for the saved files.
    :return: None.
    """
    # Look up both sections before plotting anything, so a missing key fails
    # before any file is written.
    scalar_section, array_section = metrics_json["scalar"], metrics_json["array"]
    # TODO: prepare the array metrics for ROC-AUC and plot the matrix metrics
    # (metrics_json["matrix"]) once a stratified heat map exists.
    plot_bar_chart_collection_strat(scalar_section, file_name, prefix)
    plot_roc_chart_collection_strat(array_section, file_name, prefix)
def plot_collection(metrics_json: dict, file_name: str, prefix: str) -> None:
    """
    Plot the results of a single evaluation of several algorithms.

    Scalar metrics go to bar charts, array metrics (fpr/tpr/roc_auc) to one ROC
    chart per algorithm, and matrix metrics are handed to
    ``plot_heat_map_collection``.

    :param metrics_json:dict: Evaluation results with the keys ``"scalar"``,
        ``"array"`` and ``"matrix"``.
    :param file_name:str: Base file name; metric / algorithm names are prepended to it.
    :param prefix:str: Output directory / prefix for the saved files.
    :return: None.
    """
    # All three sections are looked up up front so a missing key aborts
    # before any plot is produced.
    sections = (metrics_json["scalar"], metrics_json["array"], metrics_json["matrix"])
    plotters = (
        plot_bar_chart_collection,
        plot_roc_chart_collection,
        plot_heat_map_collection,
    )
    for plotter, section in zip(plotters, sections):
        plotter(section, file_name, prefix)
def plot_heat_map_collection(metrics_json: dict, file_name: str, prefix: str) -> None:
    """
    Placeholder for a heat-map plot of matrix metrics; nothing is drawn or saved yet.

    For every metric reported by the first algorithm, the value of that metric
    is gathered for each algorithm. Only the list for the *last* metric is
    returned (despite the ``None`` annotation). ``file_name`` and ``prefix``
    are currently unused.

    :param metrics_json:dict: Mapping algorithm name -> metric name -> value.
    :param file_name:str: Unused (reserved for the output file name).
    :param prefix:str: Unused (reserved for the output prefix).
    :return: List of the per-algorithm values of the last metric, or the
        zero-filled array from ``init_collection_plot`` when there are no metrics.
    """
    algs, metrics, collected = init_collection_plot(
        metrics_json, metric_type="scalar"
    )
    for metric in metrics:
        # Each iteration replaces the previous metric's values.
        collected = [metrics_json[alg][metric] for alg in algs]
    return collected
def plot_roc_chart_collection_strat(
    metrics_json: dict, file_name: str, prefix: str
) -> None:
    """
    Plot one fold-averaged ROC curve (mean +/- std) per algorithm.

    :param metrics_json:dict: Mapping algorithm name -> dict with the keys
        ``"fpr"`` and ``"tpr"`` (one array per fold) and ``"roc_auc"``
        (one score per fold). An empty mapping produces no plots.
    :param file_name:str: Base file name; each plot is saved as
        ``"roc_auc_stratified_<alg>_<file_name>"`` under ``prefix``.
    :param prefix:str: Output directory / prefix for the saved files.
    :return: None.
    """
    # TODO: Need to do this for multi_class classifiers, too.
    for alg, alg_metrics in metrics_json.items():
        # NOTE(review): averaging along axis 0 assumes every fold's fpr/tpr
        # array has the same length; ragged folds would need interpolation.
        fprs = np.asarray(alg_metrics["fpr"])
        tprs = np.asarray(alg_metrics["tpr"])
        roc_aucs = alg_metrics["roc_auc"]
        plot_roc_chart_strat(
            avg_fpr=np.mean(fprs, axis=0),
            avg_tpr=np.mean(tprs, axis=0),
            avg_roc_auc=np.mean(roc_aucs),
            std_tpr=np.std(tprs, axis=0),
            std_fpr=np.std(fprs, axis=0),
            std_roc_auc=np.std(roc_aucs),
            file_name="roc_auc_stratified_" + alg + "_" + file_name,
            prefix=prefix,
        )
def plot_roc_chart_collection(metrics_json: dict, file_name: str, prefix: str) -> None:
    """
    Plot and save one ROC chart per algorithm.

    :param metrics_json:dict: Mapping algorithm name -> dict with the keys
        ``"fpr"``, ``"tpr"`` and ``"roc_auc"``.
    :param file_name:str: Base file name; each chart is saved as
        ``"roc_auc_<alg>_<file_name>"`` under ``prefix``.
    :param prefix:str: Output directory / prefix for the saved files.
    :return: None.
    :doc-author: Julian M. Kleber
    """
    # TODO: Need to do this for multi_class classifiers, too.
    algs, _metric_names, _storage = init_collection_plot(
        metrics_json, metric_type="scalar"
    )
    for alg in algs:
        curve = metrics_json[alg]
        plot_roc_chart(
            curve["fpr"],
            curve["tpr"],
            curve["roc_auc"],
            file_name="roc_auc_" + alg + "_" + file_name,
            prefix=prefix,
        )
def plot_bar_chart_collection_strat(
    scalar_metrics: dict, file_name: str, prefix: str, *args
) -> None:
    """
    Plot one bar chart per scalar metric, showing the mean over repetitions of
    each algorithm with standard-deviation error bars.

    :param scalar_metrics:dict: Mapping algorithm name -> metric name -> sequence
        of values. Metric names are taken from the first algorithm.
    :param file_name:str: Base file name; each chart is saved as
        ``"<metric>_stratified_<file_name>"`` under ``prefix``.
    :param prefix:str: Output directory / prefix for the saved files.
    :param args: Ignored.
    :return: None.
    :doc-author: Julian M. Kleber
    """
    algorithms = list(scalar_metrics)
    metric_names = list(scalar_metrics[algorithms[0]])
    for metric_name in metric_names:
        means, deviations = [], []
        for alg in algorithms:  # one metric, one algorithm
            values = scalar_metrics[alg][metric_name]
            means.append(np.mean(values))
            deviations.append(np.std(values))
        # one metric, all algorithms
        plot_bar_chart(
            algorithms,
            means,
            yerr=deviations,
            y_label=metric_name,
            file_name=metric_name + "_stratified_" + file_name,
            prefix=prefix,
        )
        logging.info("plotted bar chart for the metric " + str(metric_name))
    plt.cla()
    plt.clf()
    plt.close()
def plot_bar_chart_collection(
    metrics_json: dict, file_name: str, prefix: str, *args
) -> None:
    """
    Plot one bar chart per scalar metric, comparing all algorithms.

    Entries that are missing or cannot be stored as a single float are logged
    as errors and plotted as 0.

    :param metrics_json:dict: Mapping algorithm name -> metric name -> scalar value.
        Metric names are taken from the first algorithm.
    :param file_name:str: Base file name; each chart is saved as
        ``"<metric>_<file_name>"`` under ``prefix``.
    :param prefix:str: Output directory / prefix for the saved files.
    :param args: Ignored.
    :return: None.
    :doc-author: Julian M. Kleber
    """
    matplotlib.rcParams.update({"font.size": 52})
    # metrics_storage is a zero-filled (n_metrics, n_algs) float array.
    algs, metrics, metrics_storage = init_collection_plot(
        metrics_json, metric_type="scalar"
    )
    for i, metric in enumerate(metrics):
        for j, alg in enumerate(algs):
            try:
                metrics_storage[i, j] = metrics_json[alg][metric]
            except (KeyError, TypeError, ValueError):
                # Missing key, None/non-numeric value, or a sequence (e.g. an
                # array metric): keep the 0 placeholder and continue so one bad
                # entry does not abort the whole collection.
                logging.error(
                    "Either value for alg:"
                    + str(alg)
                    + " or metric:"
                    + str(metric)
                    + " is not a valid value"
                )
        plot_bar_chart(
            algs,
            metrics_storage[i, :],
            y_label=metric,
            file_name=metric + "_" + file_name,
            prefix=prefix,
        )
        # The chart covers all algorithms, so only the metric is logged.
        logging.info("plotted bar chart for the metric " + str(metric))
    plt.cla()
    plt.clf()
    plt.close()
def plot_roc_chart_strat(
    avg_fpr: list,
    avg_tpr: list,
    avg_roc_auc: list,
    std_tpr: list,
    std_fpr: list,
    std_roc_auc: float,
    prefix: str,
    file_name: str = "roc_auc_curve",
) -> None:
    """
    Plot a fold-averaged ROC curve with a +/- 1 standard-deviation band and save it.

    The curve is drawn as ``(1 - avg_fpr, 1 - avg_tpr)``; the shaded band uses
    the same transform of ``avg_tpr -/+ std_tpr`` (clipped to [0, 1]).

    :param avg_fpr:list: Mean false-positive rates over folds.
    :param avg_tpr:list: Mean true-positive rates over folds, same length as ``avg_fpr``.
    :param avg_roc_auc:list: Mean ROC-AUC over folds, shown in the legend.
    :param std_tpr:list: Per-point standard deviation of the TPR; defines the band.
    :param std_fpr:list: Per-point standard deviation of the FPR (currently not drawn).
    :param std_roc_auc:float: Standard deviation of the ROC-AUC, shown in the legend.
    :param prefix:str: Output directory / prefix for the saved file.
    :param file_name:str="roc_auc_curve": File name; combined with ``prefix`` and
        the ``.png`` ending by ``prepare_file_name_saving``.
    :return: None.
    :doc-author: Julian M. Kleber
    """
    matplotlib.rcParams.update({"font.size": 15})
    # Accept plain lists (as annotated) as well as numpy arrays.
    avg_fpr = np.asarray(avg_fpr, dtype=float)
    avg_tpr = np.asarray(avg_tpr, dtype=float)
    std_tpr = np.asarray(std_tpr, dtype=float)
    lw = 2.5
    tprs_lower = 1 - np.maximum(avg_tpr - std_tpr, 0)
    tprs_upper = 1 - np.minimum(avg_tpr + std_tpr, 1)
    # Exactly one figure per call, closed below, so repeated calls do not
    # accumulate open figures.
    # NOTE(review): a 15x15 figsize may have been intended; the default size is
    # kept here to preserve the rendered output.
    fig, ax = plt.subplots()
    ax.plot(
        1 - avg_fpr,  # compare test with standard function
        1 - avg_tpr,
        lw=lw,
        color="black",
        label=r"Mean ROC (AUC = %0.2f $\pm$ %0.2f)" % (avg_roc_auc, std_roc_auc),
    )
    ax.plot([0, 1], [0, 1], color="navy", lw=lw, linestyle="--")
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.fill_between(
        1 - avg_fpr,
        tprs_lower,
        tprs_upper,
        color="grey",
        alpha=0.2,
        label=r"$\pm$ 1 std. dev.",
    )
    # Build the legend after all labelled artists exist so the band is listed too.
    ax.legend(loc="lower right")
    fig.tight_layout()
    file_name = prepare_file_name_saving(
        prefix=prefix, file_name=file_name, ending=".png"
    )
    fig.savefig(file_name)
    plt.close(fig)
def plot_roc_chart(
    fpr: list, tpr: list, roc_auc: list, prefix: str, file_name: str = "roc_auc_curve"
) -> None:
    """
    Plot a single ROC curve next to the diagonal reference line and save it.

    The curve is drawn as ``(1 - fpr, 1 - tpr)``.

    :param fpr:list: False-positive rates of the curve.
    :param tpr:list: True-positive rates of the curve, same length as ``fpr``.
    :param roc_auc:list: Area under the curve; its ``str()`` appears in the legend.
    :param prefix:str: Output directory / prefix for the saved file.
    :param file_name:str="roc_auc_curve": File name; combined with ``prefix`` and
        the ``.png`` ending by ``prepare_file_name_saving``.
    :return: None.
    """
    matplotlib.rcParams.update({"font.size": 32})
    fig = plt.figure(figsize=(15, 15))
    line_width = 4
    curve_x = 1 - np.array(fpr)  # compare test with standard function
    curve_y = 1 - np.array(tpr)
    plt.plot(curve_x, curve_y, lw=line_width, label="ROC curve %s" % str(roc_auc))
    # Diagonal reference line from (0, 0) to (1, 1).
    plt.plot([0, 1], [0, 1], color="navy", lw=line_width, linestyle="--")
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend(loc="lower right")
    fig.tight_layout()
    target = prepare_file_name_saving(
        prefix=prefix, file_name=file_name, ending=".png"
    )
    plt.savefig(target)
    for reset in (plt.clf, plt.cla, plt.close):
        reset()
def plot_heat_map():
    """Placeholder for a single heat-map plot; not implemented, returns None."""
def plot_grouped_bar_chart(
    labels: list,
    data_names: list,
    data: np.ndarray,
    file_name: str,
    prefix: str,
    y_label: str = "Score",
) -> None:
    """
    Draw a grouped bar chart and save it as a PNG file.

    Row ``i`` of ``data`` is one series (legend entry ``labels[i]``); column
    ``j`` is one group on the x-axis (tick label ``data_names[j]``). Within each
    group the series' bars are placed side by side, centred on the tick.

    :param labels:list: Legend label of each series (one per row of ``data``).
    :param data_names:list: Tick label of each group (one per column of ``data``).
    :param data:np.ndarray: Values with shape ``(len(labels), len(data_names))``.
    :param file_name:str: File name; combined with ``prefix`` and the ``.png``
        ending by ``prepare_file_name_saving``.
    :param prefix:str: Output directory / prefix for the saved file.
    :param y_label:str="Score": Label of the y-axis.
    :return: None.
    :raises ValueError: If ``labels`` or ``data_names`` is empty, or ``data``
        does not have shape ``(len(labels), len(data_names))``.
    """
    data = np.asarray(data)
    n_series, n_groups = len(labels), len(data_names)
    if n_series == 0 or n_groups == 0:
        raise ValueError("labels and data_names must not be empty")
    if data.shape != (n_series, n_groups):
        raise ValueError(
            "data must have shape (len(labels), len(data_names)) = "
            + str((n_series, n_groups))
            + ", got "
            + str(data.shape)
        )
    file_name = prepare_file_name_saving(
        prefix=prefix, file_name=file_name, ending=".png"
    )
    matplotlib.rcParams.update({"font.size": 22})
    fig, ax = plt.subplots(figsize=(15, 15))
    x = np.arange(n_groups)  # the group (tick) locations
    width = 0.8 / n_series  # all bars of one group share 80% of a unit slot
    for i, label in enumerate(labels):
        # Offsets run from -(n-1)/2 to +(n-1)/2 bar widths, centring the group.
        offset = (i - (n_series - 1) / 2) * width
        rects = ax.bar(x + offset, data[i, :], width, label=label)
        ax.bar_label(rects, padding=n_series + 1)
    ax.set_ylabel(y_label)
    ax.set_xticks(x)
    ax.set_xticklabels(data_names)
    ax.legend()
    fig.tight_layout()
    fig.savefig(file_name)
    plt.close(fig)
def plot_bar_chart(
    algs: list,
    metrics: list,
    file_name: str,
    prefix: str,
    y_label: str,
    yerr=None,
    *args
) -> None:
    """
    Draw a bar chart (one bar per algorithm, y-axis limited to [0, 1]) and save it.

    :param algs:list: Bar labels (x-axis categories).
    :param metrics:list: Bar heights, one per entry of ``algs``.
    :param file_name:str: File name; combined with ``prefix`` and the ``.png``
        ending by ``prepare_file_name_saving``.
    :param prefix:str: Output directory / prefix for the saved file.
    :param y_label:str: Label of the y-axis.
    :param yerr=None: Optional error-bar sizes (list, scalar or numpy array).
        Error bars are drawn when a list/scalar is truthy or an array is non-empty.
    :param args: Ignored.
    :return: None.
    """
    matplotlib.rcParams.update({"font.size": 38})
    fig = plt.figure(figsize=(15, 15))
    plt.bar(algs, metrics, width=0.5)
    # ``if yerr:`` raises "truth value ... is ambiguous" for numpy arrays with
    # more than one element, so arrays are tested by size instead.
    if isinstance(yerr, np.ndarray):
        draw_error_bars = yerr.size > 0
    else:
        draw_error_bars = bool(yerr)
    if draw_error_bars:
        # Thicker error bars when fewer than four algorithms are plotted.
        elinewidth = 5 if len(algs) < 4 else 4
        plt.errorbar(
            algs, metrics, yerr=yerr, elinewidth=elinewidth, fmt="o", color="black"
        )
    file_name = prepare_file_name_saving(
        prefix=prefix, file_name=file_name, ending=".png"
    )
    plt.ylim((0, 1))
    plt.ylabel(y_label)
    fig.tight_layout()
    fig.savefig(file_name)
    plt.close(fig)
    logging.info("Plotted " + str(file_name))
def init_collection_plot(metrics_json, metric_type="") -> tuple:
    """
    Collect algorithm and metric names and allocate a zero-filled value matrix.

    :param metrics_json: Mapping algorithm name -> metric name -> value(s).
    :param metric_type="": Currently unused.
    :return: ``(algs, metrics, storage)`` where ``storage`` is ``np.zeros`` with
        shape ``(len(metrics), len(algs))``.
    """
    algs, metrics = parse_metrics_output(metrics_json)
    storage_shape = (len(metrics), len(algs))
    return algs, metrics, np.zeros(storage_shape)
def parse_metrics_output(metrics_collection: dict) -> tuple:
    """
    Return the algorithm names and the metric names of a metrics collection.

    Metric names are taken from the first algorithm; the other algorithms are
    assumed to report the same metrics.

    :param metrics_collection:dict: Mapping algorithm name -> metric name -> value(s).
    :return: ``(algs, metrics)``, two lists of keys in insertion order. Both
        are empty when ``metrics_collection`` is empty.
    """
    algs = list(metrics_collection)
    if not algs:
        # No algorithms means no metrics; avoid an IndexError on algs[0].
        return [], []
    metrics = list(metrics_collection[algs[0]])
    return algs, metrics
def plot_combined_bar_chart(scalar_metrics: dict, file_name: str):
    """
    Draw one grouped bar chart with the mean (+/- std) of every scalar metric
    for every algorithm.

    :param scalar_metrics:dict: Mapping algorithm name -> metric name -> sequence
        of values (reduced with ``np.mean`` / ``np.std``). Metric names are taken
        from the first algorithm.
    :param file_name:str: Path of the image, passed to ``make_grouped_bar_chart``.
    :return: None.
    """
    clf_names = list(scalar_metrics)
    metric_names = list(scalar_metrics[clf_names[0]])
    # Look every metric up by name so each row follows the column order of
    # ``metric_names`` even if an algorithm lists its metrics in another order.
    metric_value_averages = [
        [np.mean(scalar_metrics[alg][name]) for name in metric_names]
        for alg in clf_names
    ]
    std_vals = [
        [np.std(scalar_metrics[alg][name]) for name in metric_names]
        for alg in clf_names
    ]
    make_grouped_bar_chart(
        metric_values=metric_value_averages,
        metric_names=metric_names,
        x_labels=clf_names,
        std_vals=std_vals,
        file_name=file_name,
        width=0.15,
    )
def make_grouped_bar_chart(
    metric_values: list,
    metric_names: list,
    x_labels: list,
    std_vals: list,
    file_name: str,
    width=0.1,
    *args,
    **kwargs
):
    """
    Draw a grouped bar chart with error bars and save it to ``file_name`` at 300 dpi.

    Each group on the x-axis is one entry of ``x_labels``; within a group there
    is one bar per metric. The first five metrics use a fixed grey palette;
    any further metrics use matplotlib's default colour cycle.

    :param metric_values: Values with shape ``(len(x_labels), len(metric_names))``.
    :param metric_names: Legend label of each metric (one per column).
    :param x_labels: Tick label of each group (one per row).
    :param std_vals: Error-bar sizes, same shape as ``metric_values``.
    :param file_name: Path of the image to write.
    :param width: Width of a single bar in x-axis units.
    :param args: Ignored.
    :param kwargs: Ignored.
    :return: None.
    """
    grey_palette = ("#494848", "#636363", "#909090", "#B4B4B4", "#D4D4D4")
    metric_values = np.asarray(metric_values, dtype=float)
    std_vals = np.asarray(std_vals, dtype=float)
    n_groups, n_metrics = metric_values.shape
    matplotlib.rcParams.update({"font.size": 18})
    # One figure per call, closed below so repeated calls do not leak figures.
    fig, ax = plt.subplots()
    group_starts = np.arange(n_groups)
    for k in range(n_metrics):
        color = grey_palette[k] if k < len(grey_palette) else None
        ax.bar(
            group_starts + k * width,
            metric_values[:, k],
            width,
            label=metric_names[k],
            yerr=std_vals[:, k],
            color=color,
        )
    ax.set_ylabel("Score")
    # Place each tick under the middle of its group of bars.
    ax.set_xticks(group_starts + width * (n_metrics - 1) / 2)
    ax.set_xticklabels(x_labels)
    ax.legend(prop={"size": 12})
    fig.tight_layout()
    fig.savefig(file_name, dpi=300)
    plt.close(fig)
if __name__ == "__main__":
    # Make the package root importable when this file is run as a script.
    sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
    from datasets.BreastCancer import BreastCancerDataset
    from actions.evaluation import screen_classifier_for_metrics

    DataSet = BreastCancerDataset(split_ratio=0.8)
    metrics = screen_classifier_for_metrics(
        X_train=DataSet.X_train,
        X_test=DataSet.X_test,
        y_train=DataSet.y_train,
        y_test=DataSet.y_test,
    )
    # Use the bar-chart helper defined in this module (there is no
    # ``plot_barchart_collection`` to import).
    # NOTE(review): assumes ``metrics`` maps algorithm name -> {metric: scalar}.
    plot_bar_chart_collection(
        metrics, file_name=DataSet.name + "_evaluation", prefix="plots/evaluation/"
    )