"""Pretty-printing (pprint()), the 'Print' Op, debugprint() and pydotprint(). They all allow different way to print a graph or the result of an Op in a graph(Print Op) """ from copy import copy import logging import os import sys # Not available on all platforms hashlib = None import numpy np = numpy try: import pydot as pd if pd.find_graphviz(): pydot_imported = True else: pydot_imported = False except ImportError: pydot_imported = False import theano from theano import gof from theano import config from theano.compat.six import StringIO from theano.gof import Op, Apply from theano.gof.python25 import any from theano.compile import Function, debugmode from theano.compile.profilemode import ProfileMode _logger = logging.getLogger("theano.printing") def debugprint(obj, depth=-1, print_type=False, file=None, ids='CHAR', stop_on_name=False): """Print a computation graph as text to stdout or a file. :type obj: Variable, Apply, or Function instance :param obj: symbolic thing to print :type depth: integer :param depth: print graph to this depth (-1 for unlimited) :type print_type: boolean :param print_type: whether to print the type of printed objects :type file: None, 'str', or file-like object :param file: print to this file ('str' means to return a string) :type ids: str :param ids: How do we print the identifier of the variable id - print the python id value int - print integer character CHAR - print capital character "" - don't print an identifier :param stop_on_name: When True, if a node in the graph has a name, we don't print anything below it. :returns: string if `file` == 'str', else file arg Each line printed represents a Variable in the graph. The indentation of lines corresponds to its depth in the symbolic graph. The first part of the text identifies whether it is an input (if a name or type is printed) or the output of some Apply (in which case the Op is printed). The second part of the text is an identifier of the Variable. If print_type is True, we add a part containing the type of the Variable If a Variable is encountered multiple times in the depth-first search, it is only printed recursively the first time. Later, just the Variable identifier is printed. If an Apply has multiple outputs, then a '.N' suffix will be appended to the Apply's identifier, to indicate which output a line corresponds to. """ if file == 'str': _file = StringIO() elif file is None: _file = sys.stdout else: _file = file done = dict() results_to_print = [] order = [] if isinstance(obj, gof.Variable): results_to_print.append(obj) elif isinstance(obj, gof.Apply): results_to_print.extend(obj.outputs) elif isinstance(obj, Function): results_to_print.extend(obj.maker.fgraph.outputs) order = obj.maker.fgraph.toposort() elif isinstance(obj, (list, tuple)): results_to_print.extend(obj) elif isinstance(obj, gof.FunctionGraph): results_to_print.extend(obj.outputs) order = obj.toposort() else: raise TypeError("debugprint cannot print an object of this type", obj) for r in results_to_print: debugmode.debugprint(r, depth=depth, done=done, print_type=print_type, file=_file, order=order, ids=ids, stop_on_name=stop_on_name) if file is _file: return file elif file == 'str': return _file.getvalue() else: _file.flush() def _print_fn(op, xin): for attr in op.attrs: temp = getattr(xin, attr) if callable(temp): pmsg = temp() else: pmsg = temp print op.message, attr, '=', pmsg class Print(Op): """ This identity-like Op print as a side effect. This identity-like Op has the side effect of printing a message followed by its inputs when it runs. Default behaviour is to print the __str__ representation. Optionally, one can pass a list of the input member functions to execute, or attributes to print. @type message: String @param message: string to prepend to the output @type attrs: list of Strings @param attrs: list of input node attributes or member functions to print. Functions are identified through callable(), executed and their return value printed. :note: WARNING. This can disable some optimizations! (speed and/or stabilization) Detailed explanation: As of 2012-06-21 the Print op is not known by any optimization. Setting a Print op in the middle of a pattern that is usually optimized out will block the optimization. for example, log(1+x) optimizes to log1p(x) but log(1+Print(x)) is unaffected by optimizations. """ view_map = {0: [0]} def __init__(self, message="", attrs=("__str__",), global_fn=_print_fn): self.message = message self.attrs = tuple(attrs) # attrs should be a hashable iterable self.global_fn = global_fn def make_node(self, xin): xout = xin.type.make_variable() return Apply(op=self, inputs=[xin], outputs=[xout]) def perform(self, node, inputs, output_storage): xin, = inputs xout, = output_storage xout[0] = xin self.global_fn(self, xin) def grad(self, input, output_gradients): return output_gradients def R_op(self, inputs, eval_points): return [x for x in eval_points] def __eq__(self, other): return (type(self) == type(other) and self.message == other.message and self.attrs == other.attrs) def __hash__(self): return hash(self.message) ^ hash(self.attrs) def __setstate__(self, dct): dct.setdefault('global_fn', _print_fn) self.__dict__.update(dct) def c_code_cache_version(self): return (1,) class PrinterState(gof.utils.scratchpad): def __init__(self, props=None, **more_props): if props is None: props = {} if isinstance(props, gof.utils.scratchpad): self.__update__(props) else: self.__dict__.update(props) self.__dict__.update(more_props) def clone(self, props=None, **more_props): if props is None: props = {} return PrinterState(self, **dict(props, **more_props)) class OperatorPrinter: def __init__(self, operator, precedence, assoc='left'): self.operator = operator self.precedence = precedence self.assoc = assoc def process(self, output, pstate): pprinter = pstate.pprinter node = output.owner if node is None: raise TypeError("operator %s cannot represent a variable that is " "not the result of an operation" % self.operator) ## Precedence seems to be buggy, see #249 ## So, in doubt, we parenthesize everything. #outer_precedence = getattr(pstate, 'precedence', -999999) #outer_assoc = getattr(pstate, 'assoc', 'none') #if outer_precedence > self.precedence: # parenthesize = True #else: # parenthesize = False parenthesize = True input_strings = [] max_i = len(node.inputs) - 1 for i, input in enumerate(node.inputs): if (self.assoc == 'left' and i != 0 or self.assoc == 'right' and i != max_i): s = pprinter.process(input, pstate.clone( precedence=self.precedence + 1e-6)) else: s = pprinter.process(input, pstate.clone( precedence=self.precedence)) input_strings.append(s) if len(input_strings) == 1: s = self.operator + input_strings[0] else: s = (" %s " % self.operator).join(input_strings) if parenthesize: return "(%s)" % s else: return s class PatternPrinter: def __init__(self, *patterns): self.patterns = [] for pattern in patterns: if isinstance(pattern, basestring): self.patterns.append((pattern, ())) else: self.patterns.append((pattern[0], pattern[1:])) def process(self, output, pstate): pprinter = pstate.pprinter node = output.owner if node is None: raise TypeError("Patterns %s cannot represent a variable that is " "not the result of an operation" % self.patterns) idx = node.outputs.index(output) pattern, precedences = self.patterns[idx] precedences += (1000,) * len(node.inputs) pp_process = lambda input, precedence: pprinter.process( input, pstate.clone(precedence=precedence)) d = dict((str(i), x) for i, x in enumerate(pp_process(input, precedence) for input, precedence in zip(node.inputs, precedences))) return pattern % d class FunctionPrinter: def __init__(self, *names): self.names = names def process(self, output, pstate): pprinter = pstate.pprinter node = output.owner if node is None: raise TypeError("function %s cannot represent a variable that is " "not the result of an operation" % self.names) idx = node.outputs.index(output) name = self.names[idx] return "%s(%s)" % (name, ", ".join( [pprinter.process(input, pstate.clone(precedence=-1000)) for input in node.inputs])) class MemberPrinter: def __init__(self, *names): self.names = names def process(self, output, pstate): pprinter = pstate.pprinter node = output.owner if node is None: raise TypeError("function %s cannot represent a variable that is" " not the result of an operation" % self.function) names = self.names idx = node.outputs.index(output) name = self.names[idx] input = node.inputs[0] return "%s.%s" % (pprinter.process(input, pstate.clone(precedence=1000)), name) class IgnorePrinter: def process(self, output, pstate): pprinter = pstate.pprinter node = output.owner if node is None: raise TypeError("function %s cannot represent a variable that is" " not the result of an operation" % self.function) input = node.inputs[0] return "%s" % pprinter.process(input, pstate) class DefaultPrinter: def __init__(self): pass def process(self, r, pstate): pprinter = pstate.pprinter node = r.owner if node is None: return LeafPrinter().process(r, pstate) return "%s(%s)" % (str(node.op), ", ".join( [pprinter.process(input, pstate.clone(precedence=-1000)) for input in node.inputs])) class LeafPrinter: def process(self, r, pstate): if r.name in greek: return greek[r.name] else: return str(r) class PPrinter: def __init__(self): self.printers = [] def assign(self, condition, printer): if isinstance(condition, gof.Op): op = condition condition = (lambda pstate, r: r.owner is not None and r.owner.op == op) self.printers.insert(0, (condition, printer)) def process(self, r, pstate=None): if pstate is None: pstate = PrinterState(pprinter=self) elif isinstance(pstate, dict): pstate = PrinterState(pprinter=self, **pstate) for condition, printer in self.printers: if condition(pstate, r): return printer.process(r, pstate) def clone(self): cp = copy(self) cp.printers = list(self.printers) return cp def clone_assign(self, condition, printer): cp = self.clone() cp.assign(condition, printer) return cp def process_graph(self, inputs, outputs, updates=None, display_inputs=False): if updates is None: updates = {} if not isinstance(inputs, (list, tuple)): inputs = [inputs] if not isinstance(outputs, (list, tuple)): outputs = [outputs] current = None if display_inputs: strings = [(0, "inputs: " + ", ".join( map(str, list(inputs) + updates.keys())))] else: strings = [] pprinter = self.clone_assign(lambda pstate, r: r.name is not None and r is not current, LeafPrinter()) inv_updates = dict((b, a) for (a, b) in updates.iteritems()) i = 1 for node in gof.graph.io_toposort(list(inputs) + updates.keys(), list(outputs) + updates.values()): for output in node.outputs: if output in inv_updates: name = str(inv_updates[output]) strings.append((i + 1000, "%s <- %s" % ( name, pprinter.process(output)))) i += 1 if output.name is not None or output in outputs: if output.name is None: name = 'out[%i]' % outputs.index(output) else: name = output.name #backport #name = 'out[%i]' % outputs.index(output) if output.name # is None else output.name current = output try: idx = 2000 + outputs.index(output) except ValueError: idx = i if len(outputs) == 1 and outputs[0] is output: strings.append((idx, "return %s" % pprinter.process(output))) else: strings.append((idx, "%s = %s" % (name, pprinter.process(output)))) i += 1 strings.sort() return "\n".join(s[1] for s in strings) def __call__(self, *args): if len(args) == 1: return self.process(*args) elif len(args) == 2 and isinstance(args[1], (PrinterState, dict)): return self.process(*args) elif len(args) > 2: return self.process_graph(*args) else: raise TypeError('Not enough arguments to call.') use_ascii = True if use_ascii: special = dict(middle_dot="\\dot", big_sigma="\\Sigma") greek = dict(alpha="\\alpha", beta="\\beta", gamma="\\gamma", delta="\\delta", epsilon="\\epsilon") else: special = dict(middle_dot=u"\u00B7", big_sigma=u"\u03A3") greek = dict(alpha=u"\u03B1", beta=u"\u03B2", gamma=u"\u03B3", delta=u"\u03B4", epsilon=u"\u03B5") pprint = PPrinter() pprint.assign(lambda pstate, r: True, DefaultPrinter()) pprint.assign(lambda pstate, r: hasattr(pstate, 'target') and pstate.target is not r and r.name is not None, LeafPrinter()) pp = pprint """ Print to the terminal a math-like expression. """ # colors not used: orange, amber#FFBF00, purple, pink, # used by default: green, blue, grey, red default_colorCodes = {'GpuFromHost': 'red', 'HostFromGpu': 'red', 'Scan': 'yellow', 'Shape': 'cyan', 'IfElse': 'magenta', 'Elemwise': '#FFAABB', # dark pink 'Subtensor': '#FFAAFF', # purple 'Alloc': '#FFAA22'} # orange def pydotprint(fct, outfile=None, compact=True, format='png', with_ids=False, high_contrast=True, cond_highlight=None, colorCodes=None, max_label_size=70, scan_graphs=False, var_with_name_simple=False, print_output_file=True, assert_nb_all_strings=-1 ): """ Print to a file (png format) the graph of a compiled theano function's ops. :param fct: the theano fct returned by theano.function. :param outfile: the output file where to put the graph. :param compact: if True, will remove intermediate var that don't have name. :param format: the file format of the output. :param with_ids: Print the toposort index of the node in the node name. and an index number in the variable ellipse. :param high_contrast: if true, the color that describes the respective node is filled with its corresponding color, instead of coloring the border :param colorCodes: dictionary with names of ops as keys and colors as values :param cond_highlight: Highlights a lazy if by sorrounding each of the 3 possible categories of ops with a border. The categories are: ops that are on the left branch, ops that are on the right branch, ops that are on both branches As an alternative you can provide the node that represents the lazy if :param scan_graphs: if true it will plot the inner graph of each scan op in files with the same name as the name given for the main file to which the name of the scan op is concatenated and the index in the toposort of the scan. This index can be printed with the option with_ids. :param var_with_name_simple: If true and a variable have a name, we will print only the variable name. Otherwise, we concatenate the type to the var name. :param assert_nb_all_strings: Used for tests. If non-negative, assert that the number of unique string nodes in the dot graph is equal to this number. This is used in tests to verify that dot won't merge Theano nodes. In the graph, ellipses are Apply Nodes (the execution of an op) and boxes are variables. If variables have names they are used as text (if multiple vars have the same name, they will be merged in the graph). Otherwise, if the variable is constant, we print its value and finally we print the type + a unique number to prevent multiple vars from being merged. We print the op of the apply in the Apply box with a number that represents the toposort order of application of those Apply. If an Apply has more than 1 input, we label each edge between an input and the Apply node with the input's index. Green boxes are inputs variables to the graph, blue boxes are outputs variables of the graph, grey boxes are variables that are not outputs and are not used, red ellipses are transfers from/to the gpu (ops with names GpuFromHost, HostFromGpu). """ if colorCodes is None: colorCodes = default_colorCodes if outfile is None: outfile = os.path.join(config.compiledir, 'theano.pydotprint.' + config.device + '.' + format) if isinstance(fct, Function): mode = fct.maker.mode profile = getattr(fct, "profile", None) if (not isinstance(mode, ProfileMode) or not fct in mode.profile_stats): mode = None fct_fgraph = fct.maker.fgraph elif isinstance(fct, gof.FunctionGraph): mode = None profile = None fct_fgraph = fct else: raise ValueError(('pydotprint expects as input a theano.function or ' 'the FunctionGraph of a function!'), fct) if not pydot_imported: raise RuntimeError("Failed to import pydot. You must install pydot" " for `pydotprint` to work.") return g = pd.Dot() if cond_highlight is not None: c1 = pd.Cluster('Left') c2 = pd.Cluster('Right') c3 = pd.Cluster('Middle') cond = None for node in fct_fgraph.toposort(): if (node.op.__class__.__name__ == 'IfElse' and node.op.name == cond_highlight): cond = node if cond is None: _logger.warn("pydotprint: cond_highlight is set but there is no" " IfElse node in the graph") cond_highlight = None if cond_highlight is not None: def recursive_pass(x, ls): if not x.owner: return ls else: ls += [x.owner] for inp in x.inputs: ls += recursive_pass(inp, ls) return ls left = set(recursive_pass(cond.inputs[1], [])) right = set(recursive_pass(cond.inputs[2], [])) middle = left.intersection(right) left = left.difference(middle) right = right.difference(middle) middle = list(middle) left = list(left) right = list(right) var_str = {} all_strings = set() def var_name(var): if var in var_str: return var_str[var] if var.name is not None: if var_with_name_simple: varstr = var.name else: varstr = 'name=' + var.name + " " + str(var.type) elif isinstance(var, gof.Constant): dstr = 'val=' + str(numpy.asarray(var.data)) if '\n' in dstr: dstr = dstr[:dstr.index('\n')] varstr = '%s %s' % (dstr, str(var.type)) elif (var in input_update and input_update[var].variable.name is not None): if var_with_name_simple: varstr = input_update[var].variable.name + " UPDATE" else: varstr = (input_update[var].variable.name + " UPDATE " + str(var.type)) else: #a var id is needed as otherwise var with the same type will be #merged in the graph. varstr = str(var.type) if (varstr in all_strings) or with_ids: idx = ' id=' + str(len(var_str)) if len(varstr) + len(idx) > max_label_size: varstr = varstr[:max_label_size - 3 - len(idx)] + idx + '...' else: varstr = varstr + idx elif len(varstr) > max_label_size: varstr = varstr[:max_label_size - 3] + '...' idx = 1 while varstr in all_strings: idx += 1 suffix = ' id=' + str(idx) varstr = (varstr[:max_label_size - 3 - len(suffix)] + '...' + suffix) var_str[var] = varstr all_strings.add(varstr) return varstr topo = fct_fgraph.toposort() apply_name_cache = {} def apply_name(node): if node in apply_name_cache: return apply_name_cache[node] prof_str = '' if mode: time = mode.profile_stats[fct].apply_time.get(node, 0) #second, % total time in profiler, %fct time in profiler if mode.local_time == 0: pt = 0 else: pt = time * 100 / mode.local_time if mode.profile_stats[fct].fct_callcount == 0: pf = 0 else: pf = time * 100 / mode.profile_stats[fct].fct_call_time prof_str = ' (%.3fs,%.3f%%,%.3f%%)' % (time, pt, pf) elif profile: time = profile.apply_time.get(node, 0) #second, %fct time in profiler if profile.fct_callcount == 0: pf = 0 else: pf = time * 100 / profile.fct_call_time prof_str = ' (%.3fs,%.3f%%)' % (time, pf) applystr = str(node.op).replace(':', '_') applystr += prof_str if (applystr in all_strings) or with_ids: idx = ' id=' + str(topo.index(node)) if len(applystr) + len(idx) > max_label_size: applystr = (applystr[:max_label_size - 3 - len(idx)] + idx + '...') else: applystr = applystr + idx elif len(applystr) > max_label_size: applystr = applystr[:max_label_size - 3] + '...' idx = 1 while applystr in all_strings: idx += 1 suffix = ' id=' + str(idx) applystr = (applystr[:max_label_size - 3 - len(suffix)] + '...' + suffix) all_strings.add(applystr) apply_name_cache[node] = applystr return applystr # Update the inputs that have an update function input_update = {} outputs = list(fct_fgraph.outputs) if isinstance(fct, Function): for i in reversed(fct.maker.expanded_inputs): if i.update is not None: input_update[outputs.pop()] = i apply_shape = 'ellipse' var_shape = 'box' for node_idx, node in enumerate(topo): astr = apply_name(node) use_color = None for opName, color in colorCodes.items(): if opName in node.op.__class__.__name__: use_color = color if use_color is None: nw_node = pd.Node(astr, shape=apply_shape) elif high_contrast: nw_node = pd.Node(astr, style='filled', fillcolor=use_color, shape=apply_shape) else: nw_node = pd.Node(astr, color=use_color, shape=apply_shape) g.add_node(nw_node) if cond_highlight: if node in middle: c3.add_node(nw_node) elif node in left: c1.add_node(nw_node) elif node in right: c2.add_node(nw_node) for id, var in enumerate(node.inputs): varstr = var_name(var) label = str(var.type) if len(node.inputs) > 1: label = str(id) + ' ' + label if len(label) > max_label_size: label = label[:max_label_size - 3] + '...' if var.owner is None: if high_contrast: g.add_node(pd.Node(varstr, style='filled', fillcolor='green', shape=var_shape)) else: g.add_node(pd.Node(varstr, color='green', shape=var_shape)) g.add_edge(pd.Edge(varstr, astr, label=label)) elif var.name or not compact: g.add_edge(pd.Edge(varstr, astr, label=label)) else: #no name, so we don't make a var ellipse g.add_edge(pd.Edge(apply_name(var.owner), astr, label=label)) for id, var in enumerate(node.outputs): varstr = var_name(var) out = any([x[0] == 'output' for x in var.clients]) label = str(var.type) if len(node.outputs) > 1: label = str(id) + ' ' + label if len(label) > max_label_size: label = label[:max_label_size - 3] + '...' if out: g.add_edge(pd.Edge(astr, varstr, label=label)) if high_contrast: g.add_node(pd.Node(varstr, style='filled', fillcolor='blue', shape=var_shape)) else: g.add_node(pd.Node(varstr, color='blue', shape=var_shape)) elif len(var.clients) == 0: g.add_edge(pd.Edge(astr, varstr, label=label)) if high_contrast: g.add_node(pd.Node(varstr, style='filled', fillcolor='grey', shape=var_shape)) else: g.add_node(pd.Node(varstr, color='grey', shape=var_shape)) elif var.name or not compact: g.add_edge(pd.Edge(astr, varstr, label=label)) # else: #don't add egde here as it is already added from the inputs. if cond_highlight: g.add_subgraph(c1) g.add_subgraph(c2) g.add_subgraph(c3) if not outfile.endswith('.' + format): outfile += '.' + format g.write(outfile, prog='dot', format=format) if print_output_file: print 'The output file is available at', outfile if assert_nb_all_strings != -1: assert len(all_strings) == assert_nb_all_strings if scan_graphs: scan_ops = [(idx, x) for idx, x in enumerate(fct_fgraph.toposort()) if isinstance(x.op, theano.scan_module.scan_op.Scan)] path, fn = os.path.split(outfile) basename = '.'.join(fn.split('.')[:-1]) # Safe way of doing things .. a file name may contain multiple . ext = fn[len(basename):] for idx, scan_op in scan_ops: # is there a chance that name is not defined? if hasattr(scan_op.op, 'name'): new_name = basename + '_' + scan_op.op.name + '_' + str(idx) else: new_name = basename + '_' + str(idx) new_name = os.path.join(path, new_name + ext) pydotprint(scan_op.op.fn, new_name, compact, format, with_ids, high_contrast, cond_highlight, colorCodes, max_label_size, scan_graphs) def pydotprint_variables(vars, outfile=None, format='png', depth=-1, high_contrast=True, colorCodes=None, max_label_size=50, var_with_name_simple=False): ''' Identical to pydotprint just that it starts from a variable instead of a compiled function. Could be useful ? ''' if colorCodes is None: colorCodes = default_colorCodes if outfile is None: outfile = os.path.join(config.compiledir, 'theano.pydotprint.' + config.device + '.' + format) try: import pydot as pd except ImportError: print ("Failed to import pydot. You must install pydot for " "`pydotprint_variables` to work.") return g = pd.Dot() my_list = {} orphanes = [] if type(vars) not in (list, tuple): vars = [vars] var_str = {} def var_name(var): if var in var_str: return var_str[var] if var.name is not None: if var_with_name_simple: varstr = var.name else: varstr = 'name=' + var.name + " " + str(var.type) elif isinstance(var, gof.Constant): dstr = 'val=' + str(var.data) if '\n' in dstr: dstr = dstr[:dstr.index('\n')] varstr = '%s %s' % (dstr, str(var.type)) else: #a var id is needed as otherwise var with the same type will be #merged in the graph. varstr = str(var.type) varstr += ' ' + str(len(var_str)) if len(varstr) > max_label_size: varstr = varstr[:max_label_size - 3] + '...' var_str[var] = varstr return varstr def apply_name(node): name = str(node.op).replace(':', '_') if len(name) > max_label_size: name = name[:max_label_size - 3] + '...' return name def plot_apply(app, d): if d == 0: return if app in my_list: return astr = apply_name(app) + '_' + str(len(my_list.keys())) if len(astr) > max_label_size: astr = astr[:max_label_size - 3] + '...' my_list[app] = astr use_color = None for opName, color in colorCodes.items(): if opName in app.op.__class__.__name__: use_color = color if use_color is None: g.add_node(pd.Node(astr, shape='box')) elif high_contrast: g.add_node(pd.Node(astr, style='filled', fillcolor=use_color, shape='box')) else: g.add_node(pd.Nonde(astr, color=use_color, shape='box')) for i, nd in enumerate(app.inputs): if nd not in my_list: varastr = var_name(nd) + '_' + str(len(my_list.keys())) if len(varastr) > max_label_size: varastr = varastr[:max_label_size - 3] + '...' my_list[nd] = varastr if nd.owner is not None: g.add_node(pd.Node(varastr)) elif high_contrast: g.add_node(pd.Node(varastr, style='filled', fillcolor='green')) else: g.add_node(pd.Node(varastr, color='green')) else: varastr = my_list[nd] label = '' if len(app.inputs) > 1: label = str(i) g.add_edge(pd.Edge(varastr, astr, label=label)) for i, nd in enumerate(app.outputs): if nd not in my_list: varastr = var_name(nd) + '_' + str(len(my_list.keys())) if len(varastr) > max_label_size: varastr = varastr[:max_label_size - 3] + '...' my_list[nd] = varastr color = None if nd in vars: color = 'blue' elif nd in orphanes: color = 'gray' if color is None: g.add_node(pd.Node(varastr)) elif high_contrast: g.add_node(pd.Node(varastr, style='filled', fillcolor=color)) else: g.add_node(pd.Node(varastr, color=color)) else: varastr = my_list[nd] label = '' if len(app.outputs) > 1: label = str(i) g.add_edge(pd.Edge(astr, varastr, label=label)) for nd in app.inputs: if nd.owner: plot_apply(nd.owner, d - 1) for nd in vars: if nd.owner: for k in nd.owner.outputs: if k not in vars: orphanes.append(k) for nd in vars: if nd.owner: plot_apply(nd.owner, depth) try: g.write_png(outfile, prog='dot') except pd.InvocationException, e: # Some version of pydot are bugged/don't work correctly with # empty label. Provide a better user error message. if pd.__version__ == "1.0.28" and "label=]" in e.message: raise Exception("pydot 1.0.28 is know to be bugged. Use another " "working version of pydot") elif "label=]" in e.message: raise Exception("Your version of pydot " + pd.__version__ + " returned an error. Version 1.0.28 is known" " to be bugged and 1.0.25 to be working with" " Theano. Using another version of pydot could" " fix this problem. The pydot error is: " + e.message) print 'The output file is available at', outfile class _TagGenerator: """ Class for giving abbreviated tags like to objects. Only really intended for internal use in order to implement min_informative_st """ def __init__(self): self.cur_tag_number = 0 def get_tag(self): rval = debugmode.char_from_number(self.cur_tag_number) self.cur_tag_number += 1 return rval def min_informative_str(obj, indent_level=0, _prev_obs=None, _tag_generator=None): """ Returns a string specifying to the user what obj is The string will print out as much of the graph as is needed for the whole thing to be specified in terms only of constants or named variables. Parameters ---------- obj: the name to convert to a string indent_level: the number of tabs the tree should start printing at (nested levels of the tree will get more tabs) _prev_obs: should only be used by min_informative_str a dictionary mapping previously converted objects to short tags Basic design philosophy ----------------------- The idea behind this function is that it can be used as parts of command line tools for debugging or for error messages. The information displayed is intended to be concise and easily read by a human. In particular, it is intended to be informative when working with large graphs composed of subgraphs from several different people's code, as in pylearn2. Stopping expanding subtrees when named variables are encountered makes it easier to understand what is happening when a graph formed by composing several different graphs made by code written by different authors has a bug. An example output is: A. Elemwise{add_no_inplace} B. log_likelihood_v_given_h C. log_likelihood_h If the user is told they have a problem computing this value, it's obvious that either log_likelihood_h or log_likelihood_v_given_h has the wrong dimensionality. The variable's str object would only tell you that there was a problem with an Elemwise{add_no_inplace}. Since there are many such ops in a typical graph, such an error message is considerably less informative. Error messages based on this function should convey much more information about the location in the graph of the error while remaining succint. One final note: the use of capital letters to uniquely identify nodes within the graph is motivated by legibility. I do not use numbers or lower case letters since these are pretty common as parts of names of ops, etc. I also don't use the object's id like in debugprint because it gives such a long string that takes time to visually diff. """ if _prev_obs is None: _prev_obs = {} indent = ' ' * indent_level if id(obj) in _prev_obs: tag = _prev_obs[id(obj)] return indent + '<' + tag + '>' if _tag_generator is None: _tag_generator = _TagGenerator() cur_tag = _tag_generator.get_tag() _prev_obs[id(obj)] = cur_tag if hasattr(obj, '__array__'): name = '<ndarray>' elif hasattr(obj, 'name') and obj.name is not None: name = obj.name elif hasattr(obj, 'owner') and obj.owner is not None: name = str(obj.owner.op) for ipt in obj.owner.inputs: name += '\n' + min_informative_str(ipt, indent_level=indent_level + 1, _prev_obs=_prev_obs, _tag_generator=_tag_generator) else: name = str(obj) prefix = cur_tag + '. ' rval = indent + prefix + name return rval def var_descriptor(obj, _prev_obs=None, _tag_generator=None): """ Returns a string, with no endlines, fully specifying how a variable is computed. Does not include any memory location dependent information such as the id of a node. """ global hashlib if hashlib is None: try: import hashlib except ImportError: raise RuntimeError("Can't run var_descriptor because hashlib is not available.") if _prev_obs is None: _prev_obs = {} if id(obj) in _prev_obs: tag = _prev_obs[id(obj)] return '<' + tag + '>' if _tag_generator is None: _tag_generator = _TagGenerator() cur_tag = _tag_generator.get_tag() _prev_obs[id(obj)] = cur_tag if hasattr(obj, '__array__'): # hashlib hashes only the contents of the buffer, but # it can have different semantics depending on the strides # of the ndarray name = '<ndarray:' name += 'strides=['+','.join(str(stride) for stride in obj.strides)+']' name += ',digest='+hashlib.md5(obj).hexdigest()+'>' elif hasattr(obj, 'owner') and obj.owner is not None: name = str(obj.owner.op) + '(' name += ','.join(var_descriptor(ipt, _prev_obs=_prev_obs, _tag_generator=_tag_generator) for ipt in obj.owner.inputs) name += ')' elif hasattr(obj, 'name') and obj.name is not None: # Only print the name if there is no owner. # This way adding a name to an intermediate node can't make # a deeper graph get the same descriptor as a shallower one name = obj.name else: name = str(obj) if ' at 0x' in name: # The __str__ method is encoding the object's id in its str name = position_independent_str(obj) if ' at 0x' in name: print name assert False prefix = cur_tag + '=' rval = prefix + name return rval def position_independent_str(obj): if isinstance(obj, theano.gof.graph.Variable): rval = 'theano_var' rval += '{type='+str(obj.type)+'}' else: raise NotImplementedError() return rval def hex_digest(x): """ Returns a short, mostly hexadecimal hash of a numpy ndarray """ global hashlib if hashlib is None: try: import hashlib except ImportError: raise RuntimeError("Can't run hex_digest because hashlib is not available.") assert isinstance(x, np.ndarray) rval = hashlib.md5(x.tostring()).hexdigest() # hex digest must be annotated with strides to avoid collisions # because the buffer interface only exposes the raw data, not # any info about the semantics of how that data should be arranged # into a tensor rval = rval + '|strides=[' + ','.join(str(stride) for stride in x.strides) + ']' rval = rval + '|shape=[' + ','.join(str(s) for s in x.shape) + ']' return rval