Source code for

#!/usr/bin/env python
# coding: utf-8

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

def arcs2adj_mat(arcs: list, n_variables: int) -> np.array:
    """
    Convert the list of arcs of a BN structure into its adjacency matrix.

    The result is a square ``n_variables x n_variables`` matrix of zeros in
    which the entry ``[parent, child]`` is set to 1 for every arc listed.

    :param arcs: list of arcs in the BN structure. The first two items of each
        arc are used directly as the row and column index of the matrix.
    :param n_variables: number of variables (side length of the matrix).
    :type arcs: list
    :type n_variables: int
    :return: adjacency matrix
    :rtype: np.array
    """

    side = n_variables
    adjacency = np.zeros((side, side))

    for arc in arcs:
        parent = arc[0]
        child = arc[1]
        # Repeated arcs simply overwrite the same cell with 1 again.
        adjacency[parent, child] = 1

    return adjacency
def _noise(n_variables: int, size: float) -> np.array:
    """
    Build a vector of alternating offsets ``+size, -size, +size, ...``.

    :param n_variables: length of the returned vector.
    :param size: magnitude of each offset.
    :type n_variables: int
    :type size: float
    :return: vector of length ``n_variables`` with ``+size`` at even indexes
        and ``-size`` at odd indexes.
    :rtype: np.array
    """

    h_noise = np.zeros(n_variables)
    # Even slots get 2 * size so that, after subtracting size from every
    # slot, even slots end at +size and odd slots at -size.
    h_noise[::2] = size * 2
    return h_noise - size


def _set_positions(variables: list) -> dict:
    """
    Assign each variable a default position on a regular grid.

    Variables are laid out row by row, left to right, with
    ``int(sqrt(len(variables)))`` columns. Rows grow downwards, so the row
    coordinate is negated.

    :param variables: list of variable names.
    :type variables: list
    :return: dictionary mapping each variable to its ``[x, y]`` position. An
        empty dictionary is returned when there are no variables.
    :rtype: dict
    """

    n_variables = len(variables)
    if n_variables == 0:
        # Without this guard the column count below would be 0 and the grid
        # computation would divide by zero.
        return {}

    n_cols = int(np.sqrt(n_variables))

    # The i-th variable sits at column i % n_cols of row i // n_cols.
    return {
        name: [index % n_cols, -(index // n_cols)]
        for index, name in enumerate(variables)
    }
def plot_bn(arcs: list, var_names: list, pos: dict = None, curved_arcs: bool = True, curvature: float = -0.3,
            node_size: int = 500, node_color: str = 'red', edge_color: str = 'black', arrow_size: int = 15,
            node_transparency: float = 0.9, edge_transparency: float = 0.9, node_line_widths: float = 2,
            title: str = None, output_file: str = None):
    """
    Plot a BN structure as a directed graph using networkx and matplotlib.

    The graph is drawn through the pyplot interface; nothing is returned. If
    ``output_file`` is given, the figure is also saved to that path.

    :param arcs: Arcs in the BN structure.
    :param var_names: List of variables.
    :param pos: Positions in the plot for each node. If missing or empty, a default grid layout is used.
    :param curved_arcs: True if curved arcs are desired.
    :param curvature: ``rad`` value of the ``arc3`` connection style used for curved edges. By default, -0.3.
    :param node_size: Size of the nodes in the graph. By default, 500.
    :param node_color: Color set to nodes. By default, 'red'.
    :param edge_color: Color set to edges. By default, 'black'.
    :param arrow_size: Size of arrows in edges. By default, 15.
    :param node_transparency: Alpha value [0, 1] that defines the transparency of the node. By default, 0.9.
    :param edge_transparency: Alpha value [0, 1] that defines the transparency of the edge. By default, 0.9.
    :param node_line_widths: Width of the nodes contour lines. By default, 2.0.
    :param title: Title for Figure. By default, None.
    :param output_file: Path to save the figure locally.
    :type arcs: list(tuple)
    :type var_names: list
    :type pos: dict {name of variables: tuples with coordinates}
    :type curved_arcs: bool
    :type curvature: float
    :type node_size: int
    :type node_color: str
    :type edge_color: str
    :type arrow_size: int
    :type node_transparency: float
    :type edge_transparency: float
    :type node_line_widths: float
    :type title: str
    :type output_file: str
    :return: None.
    """

    g = nx.DiGraph()
    g.add_nodes_from(var_names)
    g.add_edges_from(arcs)

    if not pos:
        pos = _set_positions(var_names)

    nx.draw_networkx_nodes(g, pos, node_size=node_size, node_color=node_color, alpha=node_transparency,
                           linewidths=node_line_widths)

    # Options shared by straight and curved edges. The transparency is applied
    # in both cases, and node_size is forwarded so that the arrow end points
    # are computed for the size of the nodes actually drawn.
    edge_options = {
        'edgelist': g.edges(),
        'edge_color': edge_color,
        'arrowsize': arrow_size,
        'alpha': edge_transparency,
        'node_size': node_size,
    }
    if curved_arcs:
        edge_options['connectionstyle'] = "arc3,rad=" + str(curvature)

    nx.draw_networkx_edges(g, pos, **edge_options)
    nx.draw_networkx_labels(g, pos)

    plt.axis('off')
    if title:
        plt.title(title)
    plt.tight_layout()

    if output_file:
        plt.savefig(output_file)