"""Qt widgets for showing and interactively modifying a disconnectivity graph."""
import sys
from PyQt4 import QtGui
from PyQt4.QtGui import QApplication, QWidget, QColorDialog, QInputDialog
from PyQt4.QtCore import pyqtSlot
import networkx as nx
import matplotlib.colors as col

import dgraph_browser
from pele.utils.disconnectivity_graph import DisconnectivityGraph, database2graph
from pele.storage import Database, TransitionState
from pele.utils.events import Signal
from pele.rates import RatesLinalg, RateCalculation, compute_committors


def check_thermodynamic_info(transition_states):
    """Return False if any valid transition state or minimum lacks thermodynamic data.

    Every transition state and both of its end-point minima are inspected.
    Objects flagged ``invalid`` are ignored; every other object must have
    both ``fvib`` and ``pgorder`` set (not None).
    """
    for ts in transition_states:
        for obj in (ts, ts.minimum1, ts.minimum2):
            if obj.invalid:
                continue
            if obj.fvib is None or obj.pgorder is None:
                return False
    return True


def minimum_energy_path_old(graph, m1, m2):
    """Return a path from m1 to m2 that minimizes the summed transition state energies.

    Superseded by :func:`minimum_energy_path`. Each edge of ``graph`` is
    given (in place) a ``weight`` attribute equal to its transition state
    energy minus the lowest minimum energy, and the weighted shortest path
    is returned. This is *not* the minimum energy path: it minimizes the
    sum of energies along the path rather than the highest barrier.
    """
    emin = min(m.energy for m in graph.nodes())
    for _u, _v, data in graph.edges(data=True):
        data["weight"] = data["ts"].energy - emin
    return nx.shortest_path(graph, m1, m2, weight="weight")


def minimum_energy_path(graph, m1, m2):
    """Return the minimum energy path (list of minima) between m1 and m2.

    The path between two nodes of a minimum spanning tree minimizes the
    largest edge weight along it, so routing through the spanning tree
    built from transition state energies yields the path whose highest
    barrier is lowest. Each edge of ``graph`` is given (in place) an
    ``energy`` attribute. networkx raises if m1 and m2 are not connected.
    """
    for _u, _v, data in graph.edges(data=True):
        data["energy"] = data["ts"].energy
    mst = nx.minimum_spanning_tree(graph, weight="energy")
    return nx.shortest_path(mst, m1, m2)


class TreeLeastCommonAncestor(object):
    """Find the least common ancestor of a set of trees.

    Parameters
    ----------
    trees : iterable of trees
        nodes of a disconnectivity-graph tree; each must provide
        ``get_ancestors()``.

    Attributes
    ----------
    least_common_ancestor :
        the deepest tree that is an ancestor of (or identical to) every
        input tree; computed by :meth:`run`, which the constructor calls.
    """

    def __init__(self, trees):
        # materialize so the trees can be iterated again later
        self.start_trees = list(trees)
        self.run()

    def run(self):
        """Compute, store and return the least common ancestor.

        Raises Exception if the trees share no ancestor (this includes an
        empty list of trees).
        """
        common_ancestors = None
        for tree in self.start_trees:
            lineage = set(tree.get_ancestors())
            lineage.add(tree)
            if common_ancestors is None:
                common_ancestors = lineage
            else:
                # keep only what is shared with every tree seen so far
                common_ancestors &= lineage

        if not common_ancestors:
            raise Exception("the trees don't have any common ancestors")

        # the common ancestors all lie on the lineage of the first tree, a
        # single branch, so the least common ancestor is the deepest one
        self.least_common_ancestor = max(
            common_ancestors, key=lambda t: len(list(t.get_ancestors())))
        return self.least_common_ancestor

    def get_all_paths_to_common_ancestor(self):
        """Return the set of input trees plus their ancestors up to the least common ancestor.

        The least common ancestor itself is always part of the result.
        """
        lca = self.least_common_ancestor
        trees = set(self.start_trees)
        for tree in self.start_trees:
            if tree == lca:
                # nothing above the least common ancestor belongs to the path
                continue
            # the early break relies on get_ancestors() walking upwards from
            # the parent towards the root
            for parent in tree.get_ancestors():
                trees.add(parent)
                if parent == lca:
                    break
        return trees


def _ask_temperature(parent, label_text):
    """Ask for a temperature with a modal dialog pre-filled with 1.0.

    Returns the entered value, or None if the dialog was cancelled.
    """
    dialog = QInputDialog(parent=parent)
    dialog.setLabelText(label_text)
    dialog.setInputMode(QInputDialog.DoubleInput)
    dialog.setDoubleValue(1.)
    dialog.exec_()
    if dialog.result():
        return dialog.doubleValue()
    return None


class LabelMinimumAction(QtGui.QAction):
    """Context-menu action that asks for a text label for a minimum.

    The label is stored in ``parent._minima_labels`` keyed by the minimum,
    so ``parent`` is expected to be a DGraphWidget.
    """

    def __init__(self, minimum, parent=None):
        QtGui.QAction.__init__(self, "add label", parent)
        # NOTE: shadows QObject.parent(); kept because callers may rely on it
        self.parent = parent
        self.minimum = minimum
        self.triggered.connect(self.__call__)

    def __call__(self, val):
        dialog = QInputDialog(parent=self.parent)
        dialog.setLabelText("set label for minimum: " + str(self.minimum.energy))
        dialog.setInputMode(QInputDialog.TextInput)
        dialog.exec_()
        if dialog.result():
            self.parent._minima_labels[self.minimum] = dialog.textValue()


class ColorPathAction(QtGui.QAction):
    """Context-menu action that colours the minimum energy path from minimum1 to minimum2."""

    def __init__(self, minimum1, minimum2, parent=None):
        QtGui.QAction.__init__(self, "show path to %d" % (minimum2._id), parent)
        self.parent = parent
        self.minimum1 = minimum1
        self.minimum2 = minimum2
        self.triggered.connect(self.__call__)

    def __call__(self, val):
        self.parent._color_minimum_energy_path(self.minimum1, self.minimum2)


class ColorMFPTAction(QtGui.QAction):
    """Context-menu action that colours the minima by mean first passage time to minimum1.

    The temperature is asked for in a dialog.
    """

    def __init__(self, minimum1, parent=None):
        QtGui.QAction.__init__(self, "color by mfpt", parent)
        self.parent = parent
        self.minimum1 = minimum1
        self.triggered.connect(self.__call__)

    def __call__(self, val):
        T = _ask_temperature(self.parent, "Temperature for MFPT calculation")
        if T is not None:
            self.parent._color_by_mfpt(self.minimum1, T=T)


class ColorCommittorAction(QtGui.QAction):
    """Context-menu action that colours the graph by committor probabilities.

    The temperature is asked for in a dialog.
    """

    def __init__(self, minimum1, minimum2, parent=None):
        QtGui.QAction.__init__(self, "color by committor %d" % (minimum2._id), parent)
        self.parent = parent
        self.minimum1 = minimum1
        self.minimum2 = minimum2
        self.triggered.connect(self.__call__)

    def __call__(self, val):
        T = _ask_temperature(self.parent, "Temperature for committor calculation")
        if T is not None:
            self.parent._color_by_committor(self.minimum1, self.minimum2, T=T)


class LayoutByCommittorAction(QtGui.QAction):
    """Context-menu action that lays out (and colours) the graph by committor probabilities.

    The temperature is asked for in a dialog.
    """

    def __init__(self, minimum1, minimum2, parent=None):
        QtGui.QAction.__init__(self, "layout by committor %d" % (minimum2._id), parent)
        self.parent = parent
        self.minimum1 = minimum1
        self.minimum2 = minimum2
        self.triggered.connect(self.__call__)

    def __call__(self, val):
        T = _ask_temperature(self.parent, "Temperature for committor calculation")
        if T is not None:
            self.parent._layout_by_committor(self.minimum1, self.minimum2, T=T)


class DGraphWidget(QWidget):
    """
    dialog for showing and modifying the disconnectivity graph

    Parameters
    ----------
    database : Database object
    graph : networkx Graph, optional
        you can bypass the database and pass a graph directly.  if you pass the graph,
        pass None as the database
    params : dict, optional
        initialize the values for the disconnectivity graph
    parent : QWidget, optional
    """

    def __init__(self, database, graph=None, params=None, parent=None):
        super(DGraphWidget, self).__init__(parent=parent)

        self.database = database
        self.graph = graph

        self.ui = dgraph_browser.Ui_Form()
        self.ui.setupUi(self)
        self.canvas = self.ui.widget.canvas

        # copy so that changes made here never leak back into the caller's dict
        self.input_params = {} if params is None else params.copy()
        self.params = {}
        # drawing options; refreshed from the GUI by _get_input_parameters
        self.line_width = 0.5
        self.title = ""
        self.show_minima = True
        self.show_trees = False
        self.set_defaults()

        self.minimum_selected = Signal()
        self._selected_minimum = None

        self.colour_tree = []
        self.tree_selected = None

        # matplotlib callback ids, plus the artists / item lists the pick
        # handlers use to map a clicked point back to a tree or minimum
        self._tree_cid = None
        self._minima_cid = None
        self._treepoints = None
        self._tree_list = []
        self._minima_points = None
        self._minima_list = []

        self._minima_labels = dict()

    def _set_checked(self, keyword, default):
        """Initialize the check box ``self.ui.chkbx_<keyword>``.

        The value comes from the input parameters if present, else ``default``.
        """
        value = self.input_params.get(keyword, default)
        getattr(self.ui, "chkbx_" + keyword).setChecked(bool(value))

    def _set_lineEdit(self, keyword, default=None):
        """Initialize the text of ``self.ui.lineEdit_<keyword>``.

        The value comes from the input parameters if present, else
        ``default``; nothing is written when that value is None.
        """
        value = self.input_params.get(keyword, default)
        if value is not None:
            getattr(self.ui, "lineEdit_" + keyword).setText(str(value))

    def _get_lineEdit(self, keyword, convert):
        """Return the text of ``self.ui.lineEdit_<keyword>`` passed through
        ``convert``, or None if the field is empty."""
        text = getattr(self.ui, "lineEdit_" + keyword).text()
        if len(text) > 0:
            return convert(text)
        return None

    def set_defaults(self):
        """Fill the GUI controls from the input parameters or built-in defaults."""
        self._set_checked("center_gmin", True)
        self._set_checked("show_minima", True)
        self._set_checked("order_by_energy", False)
        self._set_checked("order_by_basin_size", True)
        self._set_checked("include_gmin", True)
        self._set_checked("show_trees", False)

        self._set_lineEdit("Emax")
        self._set_lineEdit("subgraph_size")
        self._set_lineEdit("nlevels")
        self._set_lineEdit("linewidth", default=0.5)

    def _get_input_parameters(self):
        """Read the GUI controls into ``self.params`` and the drawing options.

        ``self.params`` is rebuilt from the input parameters (without the
        GUI-only ``show_minima`` key), updated with the values entered in
        the GUI, and later passed to DisconnectivityGraph. The drawing
        options ``line_width``, ``title``, ``show_minima`` and ``show_trees``
        are stored as attributes. Returns ``self.params``.

        NOTE(review): other GUI-only keys such as ``show_trees`` or
        ``linewidth`` are passed on unchanged; confirm DisconnectivityGraph
        accepts them.
        """
        params = self.input_params.copy()
        params.pop("show_minima", None)
        self.params = params

        # (line edit name, DisconnectivityGraph keyword, conversion)
        for keyword, key, convert in (("Emax", "Emax", float),
                                      ("subgraph_size", "subgraph_size", int),
                                      ("nlevels", "nlevels", int),
                                      ("offset", "node_offset", float)):
            value = self._get_lineEdit(keyword, convert)
            if value is not None:
                params[key] = value

        line_width = self._get_lineEdit("linewidth", float)
        if line_width is not None:
            self.line_width = line_width
        self.title = self.ui.lineEdit_title.text()

        params["center_gmin"] = self.ui.chkbx_center_gmin.isChecked()
        self.show_minima = self.ui.chkbx_show_minima.isChecked()
        params["order_by_energy"] = self.ui.chkbx_order_by_energy.isChecked()
        params["order_by_basin_size"] = self.ui.chkbx_order_by_basin_size.isChecked()
        params["include_gmin"] = self.ui.chkbx_include_gmin.isChecked()
        self.show_trees = self.ui.chkbx_show_trees.isChecked()
        return params

    @pyqtSlot()
    def on_btnRedraw_clicked(self):
        self.redraw_disconnectivity_graph()

    @pyqtSlot()
    def on_btnRebuild_clicked(self):
        self.rebuild_disconnectivity_graph()

    def redraw_disconnectivity_graph(self):
        """Redraw the existing disconnectivity graph with the current drawing options."""
        self._get_input_parameters()
        self._draw_disconnectivity_graph(self.show_minima, self.show_trees)

    def rebuild_disconnectivity_graph(self):
        """Recompute the disconnectivity graph from scratch and draw it.

        Any minimum labels added so far are discarded.
        """
        self._get_input_parameters()
        self._minima_labels = dict()
        self._build_disconnectivity_graph(**self.params)
        self._draw_disconnectivity_graph(self.show_minima, self.show_trees)

    def _build_disconnectivity_graph(self, **params):
        """Build the graph (from the database, if there is one) and compute ``self.dg``.

        When reading from the database, ``Emax`` is forwarded to
        database2graph only if ``T`` is not also given.
        """
        if self.database is not None:
            if "Emax" in params and "T" not in params:
                self.graph = database2graph(self.database, Emax=params["Emax"])
            else:
                self.graph = database2graph(self.database)
        dg = DisconnectivityGraph(self.graph, **params)
        dg.calculate()
        self.dg = dg

    def _get_tree_layout(self, tree):
        """Return (trees, x positions, energies) for every subtree of ``tree``.

        Leaves sit at the energy of their minimum, internal nodes at their
        ``ethresh`` energy.
        """
        treelist = []
        xlist = []
        energies = []
        for subtree in tree.get_all_trees():
            treelist.append(subtree)
            xlist.append(subtree.data["x"])
            if subtree.is_leaf():
                energies.append(subtree.data["minimum"].energy)
            else:
                energies.append(subtree.data["ethresh"])
        return treelist, xlist, energies

    def _on_pick_tree(self, event):
        """matplotlib pick callback: recolour the clicked tree and all its subtrees.

        Pick events on other artists are ignored. A colour dialog is shown
        and, if accepted, the chosen RGB colour is applied and the graph is
        redrawn.
        """
        if event.artist is not self._treepoints:
            return True
        self.tree_selected = self._tree_list[event.ind[0]]
        print("tree clicked on %s" % (self.tree_selected,))

        color_dialog = QColorDialog(parent=self)
        color_dialog.exec_()
        if color_dialog.result():
            rgba = color_dialog.selectedColor().getRgbF()  # red green blue alpha
            print("color %s" % (rgba,))
            rgb = rgba[:3]
            for tree in self.tree_selected.get_all_trees():
                tree.data["colour"] = rgb
            self.redraw_disconnectivity_graph()

    def _color_minimum_energy_path(self, m1, m2):
        """Colour red the part of the graph along the minimum energy path from m1 to m2.

        The trees of all minima on the path, up to but excluding their least
        common ancestor, are coloured so that the colouring does not reach
        higher energies than necessary.
        """
        path = minimum_energy_path(self.graph, m1, m2)
        print("there are %d minima in the path from %s to %s" % (len(path), m1._id, m2._id))

        trees = [self.dg.minimum_to_leave[m] for m in path]
        ancestry = TreeLeastCommonAncestor(trees)
        all_trees = ancestry.get_all_paths_to_common_ancestor()
        # drop the least common ancestor so the colouring doesn't go to higher
        # energies, unless it is the only tree (a path with a single minimum)
        if len(all_trees) > 1:
            all_trees.discard(ancestry.least_common_ancestor)

        for tree in all_trees:
            tree.data["colour"] = (1., 0., 0.)
        self.redraw_disconnectivity_graph()

    def _get_cluster_transition_states(self, min1):
        """Return all transition states in the connected component of ``self.graph`` containing min1.

        Raises Exception if any of them (or their minima) lack
        thermodynamic information.
        """
        component = nx.node_connected_component(self.graph, min1)
        # the component is closed, so every edge incident to it lies inside it
        transition_states = [data["ts"] for _u, _v, data
                             in self.graph.edges(component, data=True)]
        if not check_thermodynamic_info(transition_states):
            raise Exception("The thermodynamic information is not yet computed")
        return transition_states

    def _color_by_mfpt(self, min1, T=1.):
        """Colour the minima by their mean first passage time to min1 at temperature T.

        Minima without a computed time get the largest time found.
        Raises Exception if min1 is not connected to any other minimum or
        thermodynamic information is missing.
        """
        print("coloring by the mean first passage time to get to minimum %s" % (min1._id,))
        transition_states = self._get_cluster_transition_states(min1)

        # RatesLinalg needs a non-empty B set; any other minimum in the
        # cluster is used (the times used below are presumably those to
        # A = [min1] -- NOTE(review): confirm against RatesLinalg)
        others = [m for ts in transition_states
                  for m in (ts.minimum2, ts.minimum1) if m != min1]
        if not others:
            raise Exception("minimum %s is not connected to any other minimum" % (min1._id,))
        A = [min1]
        B = [others[0]]

        rcalc = RatesLinalg(transition_states, A, B, T=T)
        rcalc.compute_rates()
        mfptimes = rcalc.get_mfptimes()
        tmax = max(mfptimes.values())

        def get_mfpt(m):
            return mfptimes.get(m, tmax)

        self.dg.color_by_value(get_mfpt)
        self.redraw_disconnectivity_graph()

    def _compute_committors(self, min1, min2, T):
        """Return the committors over min1's cluster, computed with A = [min2] and B = [min1]."""
        transition_states = self._get_cluster_transition_states(min1)
        return compute_committors(transition_states, [min2], [min1], T=T)

    def _color_by_committor(self, min1, min2, T=1.):
        """Colour the minima by committor probability between min1 and min2 at temperature T.

        Minima absent from the committor result get the value 1.
        NOTE(review): if compute_committors omits the end points, min2
        is also coloured as 1 -- confirm this is intended.
        """
        print("coloring by the probability that a trajectory gets to minimum %s before %s"
              % (min1._id, min2._id))
        committors = self._compute_committors(min1, min2, T)

        def get_committor(m):
            return committors.get(m, 1.)

        self.dg.color_by_value(get_committor)
        self.redraw_disconnectivity_graph()

    def _layout_by_committor(self, min1, min2, T=1.):
        """Lay out the x axis of the graph by committor probability, and colour by it.

        Uses the same committors as :meth:`_color_by_committor` and prints
        a short summary of their values.
        """
        print("laying out by the probability that a trajectory gets to minimum %s before %s"
              % (min1._id, min2._id))
        committors = self._compute_committors(min1, min2, T)
        values = list(committors.values())
        print("maximum committor %s" % (max(values),))
        print("minimum committor %s" % (min(values),))
        print("number of committors near 1 %d" % sum(1 for v in values if v > 1. - 1e-4))
        print("number of committors equal to 1 %d" % sum(1 for v in values if v == 1.))

        def get_committor(m):
            return committors.get(m, 1.)

        self.dg.get_value = get_committor
        self.dg._layout_x_axis(self.dg.tree_graph)
        self.dg.color_by_value(get_committor)
        self.redraw_disconnectivity_graph()

    def _on_left_click_minimum(self, minimum):
        """Select ``minimum``: emit ``minimum_selected`` and show it in the label."""
        print("you clicked on minimum with id %s and energy %s" % (minimum._id, minimum.energy))
        self.minimum_selected(minimum)
        self._selected_minimum = minimum
        self.ui.label_selected_minimum.setText("%g (%d)" % (minimum.energy, minimum._id))

    def _on_right_click_minimum(self, minimum):
        """Show a context menu with the actions available for ``minimum``."""
        menu = QtGui.QMenu("list menu", parent=self)

        menu.addAction(LabelMinimumAction(minimum, parent=self))
        if self._selected_minimum is not None:
            # actions relating the clicked minimum to the selected one
            menu.addAction(ColorPathAction(minimum, self._selected_minimum, parent=self))
            menu.addAction(ColorCommittorAction(minimum, self._selected_minimum, parent=self))
            menu.addAction(LayoutByCommittorAction(minimum, self._selected_minimum, parent=self))
        menu.addAction(ColorMFPTAction(minimum, parent=self))

        menu.exec_(QtGui.QCursor.pos())

    def _on_pick_minimum(self, event):
        """matplotlib pick callback: dispatch a click on a minimum by mouse button."""
        if event.artist is not self._minima_points:
            return True
        minimum = self._minima_list[event.ind[0]]
        if event.mouseevent.button == 3:  # right mouse button
            self._on_right_click_minimum(minimum)
        else:
            self._on_left_click_minimum(minimum)

    def _draw_disconnectivity_graph(self, show_minima=True, show_trees=False):
        """Draw ``self.dg`` on the canvas and (re)connect the pick handlers."""
        ax = self.canvas.axes
        ax.clear()
        if hasattr(ax, "hold"):
            # Axes.hold does not exist in recent matplotlib versions
            ax.hold(True)
        dg = self.dg

        # plot the lines and set up the rest of the plot using the built in function
        # this might change some of the minima x positions, so this has to go before
        # anything dependent on those positions
        dg.plot(axes=ax, show_minima=False, linewidth=self.line_width, title=self.title)
        if len(self._minima_labels) > 0:
            dg.label_minima(self._minima_labels, axes=ax)
        self.ui.widget.canvas.fig.tight_layout()

        # draw the tree nodes as clickable points
        if self._tree_cid is not None:
            self.canvas.mpl_disconnect(self._tree_cid)
            self._tree_cid = None
        if show_trees:
            tree_list, x_pos, energies = self._get_tree_layout(dg.tree_graph)
            self._treepoints = ax.scatter(x_pos, energies, picker=5, color='red', alpha=0.5)
            self._tree_list = tree_list
            self._tree_cid = self.canvas.mpl_connect('pick_event', self._on_pick_tree)

        # draw the minima as clickable points
        if self._minima_cid is not None:
            self.canvas.mpl_disconnect(self._minima_cid)
            self._minima_cid = None
        if show_minima:
            xpos, minima = dg.get_minima_layout()
            energies = [m.energy for m in minima]
            self._minima_points = ax.scatter(xpos, energies, picker=5)
            self._minima_list = minima
            self._minima_cid = self.canvas.mpl_connect('pick_event', self._on_pick_minimum)

        self.canvas.draw()


class DGraphDialog(QtGui.QMainWindow):
    """Main window wrapping a DGraphWidget.

    ``app`` is accepted for backward compatibility and is not used.
    """

    def __init__(self, database, graph=None, params=None, parent=None, app=None):
        super(DGraphDialog, self).__init__(parent=parent)
        self.setWindowTitle("Disconnectivity graph")
        self.dgraph_widget = DGraphWidget(database, graph, params, parent=self)
        self.setCentralWidget(self.dgraph_widget)

    def rebuild_disconnectivity_graph(self):
        self.dgraph_widget.rebuild_disconnectivity_graph()


def reduced_db2graph(db, Emax):
    """Make a networkx graph from a database using only states with energy <= Emax.

    When several transition states connect the same pair of minima, the
    lowest-energy one ends up as the edge's ``ts`` attribute.
    """
    from pele.storage.database import Minimum
    g = nx.Graph()
    # js850> It's not strictly necessary to add the minima explicitly here,
    # but for some reason it is much faster if you do (factor of 2).  Even
    # if this means there are many more minima in the graph.  I'm not sure
    # why this is.  This step is already often the bottleneck of the d-graph
    # calculation.
    minima = db.session.query(Minimum).filter(Minimum.energy <= Emax)
    g.add_nodes_from(minima)
    # add the transition states in order of decreasing energy: for duplicate
    # edges the lowest-energy transition state is added last and wins
    ts = db.session.query(TransitionState).filter(TransitionState.energy <= Emax)\
        .order_by(-TransitionState.energy)
    for t in ts:
        g.add_edge(t.minimum1, t.minimum2, ts=t)
    return g


if __name__ == "__main__":
    db = Database("lj31.db", createdb=False)
    if len(db.minima()) < 2:
        raise Exception("database has fewer than two minima")

    compute_thermodynamics = True
    if compute_thermodynamics:
        from pele.systems import LJCluster
        from pele.thermodynamics import get_thermodynamic_information
        system = LJCluster(31)
        get_thermodynamic_information(system, db, nproc=10)

    app = QApplication(sys.argv)
    md = DGraphDialog(db)
    md.show()
    md.rebuild_disconnectivity_graph()
    sys.exit(app.exec_())