```'''
Created on Jun 14, 2010

@author: jonathanfriedman
'''

import cPickle as pickle
import matplotlib.pyplot as plt
import networkx as nx
import scipy.stats as stats
import numpy as np

from copy import deepcopy
from networkx import Graph, DiGraph
from pandas import DataFrame as DF

def make_network(frame, cutoff=None, **kwargs):
## check that row and column labels are identical
err_msg = 'Frame must have the same labels for rows and columns'
np.testing.assert_array_equal(frame.index, frame.columns, err_msg)
## check if square matrix
err_msg = 'Frame must be symmetric'
np.testing.assert_array_equal(frame, frame.T, err_msg)

G = Graph()
if cutoff is not None:
add_edges(G, frame, cutoff, **kwargs)
return G

def add_edges(G, frame, cutoff, method='larger', abs_val=True, nodes=None):
'''
Add edges to network using specified method.

method [str]: method used to obtain statistic
'''
from itertools import combinations, product
## check that row and column labels are identical
## check if square matrix
ids = G.nodes_iter()

# if node is given iterate only over pairs containing that node
if nodes is not None:
if not hasattr(nodes, '__iter__'):
nodes = (nodes,)
pairs = product(ids, nodes)
else:
pairs = combinations(ids,2)

## go over all pairs
method = method.strip().lower()
if method not in ('larger', 'smaller'):
raise ValueError, 'Unsupported method "%s"' %method
for pair in pairs:
p1,p2  = pair
weight = frame[p1][p2]
sign   = np.sign(weight)
if abs_val:
stat = np.abs(weight)
else:
stat = weight
if method=='larger' and stat >= cutoff:
elif method=='smaller' and stat <= cutoff:
G.add_edge(p1, p2, weight=weight, sign=sign)

#def make_network(abunds, algo, **kwargs):
#    '''
#    Make network object from abundance time-series.
#    Inputs:
#        abunds = [MultiAbundanceTS] object with all OTU abundance time-series.
#        algo   = [str] algorithm to compute interaction support.
#                 Valid inputs:
#                     Tau  = kendall tau correlation coefficient
#                     MI   = Mutual information
#                     TE   = transfer entropy
#                     PI   = predictability improvement
#                     SVAR = sparse vector autoregression
#    '''
#    if algo  == 'Tau':
#        norm_tot     = kwargs.get('norm_tot', False)
#        noisy_counts = kwargs.get('noisy_counts', False)
#        lag          = kwargs.get('lag', 0)
#        tau, pval = abunds.kendall(lag = lag, norm_tot = norm_tot, noisy_counts = noisy_counts)
#        tau_mat   = PairMatrix(tau,  abunds.keys())
#        pval_mat  = PairMatrix(pval, abunds.keys())
#        if lag: net = OTUnetworkDi(abunds = abunds, matrix = tau_mat, p_vals = {'direct':pval_mat}, algo = algo, lag = lag)
#        else:   net = OTUnetwork(  abunds = abunds, matrix = tau_mat, p_vals = {'direct':pval_mat}, algo = algo, lag = 0)
#
#    elif algo == 'PI':
#        R      = kwargs.get('R', 2)
#        PI     = abunds.pred_imporve(R = R)
#        PI_mat = PairMatrix(PI,  abunds.keys())
#        net    = OTUnetworkDi(abunds = abunds, matrix = PI_mat, algo = algo)
#
#    elif algo == 'SVAR':
#        LV  = kwargs.get('LV', False)
#        ids, b, b0, cv_errors = abunds.SVAR_all(1, LV = LV)
#        b_mat = PairMatrix(b, ids)
#        net   = OTUnetworkDi(abunds = abunds, matrix = b_mat, algo = algo)
#
#    elif algo == 'MI':
#        method = kwargs.get('method', 'cmeans_ML')
#        k      = kwargs.get('k', 3)
#        r      = kwargs.get('r', 2)
#        MI      = abunds.MI(method = method, k=k, r =r)
#        MI_mat  = PairMatrix(MI,  abunds.keys())
#        net = OTUnetwork(abunds = abunds, matrix = MI_mat, algo = algo)
#
#    elif algo == 'TE':
#        method = kwargs.get('method', 'cmeans_ML')
#        k      = kwargs.get('k', 3)
#        r      = kwargs.get('r', 2)
#        TE      = abunds.transfer_entropy(k=k, method = method)
#        TE_mat  = PairMatrix(TE,  abunds.keys())
#        net = OTUnetwork(abunds = abunds, matrix = TE_mat, algo = algo)
#    return net

def remove_edges(G):
'''
Remove all edges from graph
'''
G.remove_edges_from(G.edges())
pass

def remove_disc_nodes(G):
'''
Remove all nodes that have no edges.
Return new instance.
'''
from copy import deepcopy
G_reduced = G.copy()
degree    = G.degree()
nodes     = degree.keys()
remove    = filter(lambda node: degree[node]==0, nodes)
G_reduced.remove_nodes_from(remove)
return G_reduced

def community_detection(G, **kwargs):
'''
Do community detection on network.
Return list of lists representing the nodes in each community.

Online doc: http://igraph.sourceforge.net/doc/python/index.html
'''
import igraph
algo  = kwargs.get('algo','modularity') # algorithm to be used for community detection
## convert to iGraph
file = 'temp.gml'
nx.write_gml(G,file)
## detect community
if algo is 'modularity': comm = iG.community_fastgreedy(weights = 'weight')
elif algo is 'walk':
steps  = kwargs.get('steps',4)
comm = iG.community_walktrap(weights = 'weight', steps = steps)
#    elif algo is 'spinglass':
#        gamma  = kwargs.get('gamma',1)
#        comm = iG.community_spinglass(weights = 'weight', gamma = gamma)
m  = np.array(comm.membership)
## convert membership to list of lists, and make membership dict
partitions   = []
membership_d = {}
for i in range(m.max()+1):
members = np.where(m ==i)[0]
member_ids = map(lambda j: iG.vs[int(j)].attributes()['label']  ,members)
partitions.append(member_ids)
for id in member_ids: membership_d[id] = i
return partitions, membership_d

def component_net(G, components, **kwargs):
'''
Make a new network where each node represents all nodes of G in a given component.
Interaction matrix is the average interaction between all pairs in components.
'''
from itertools import combinations, product

statistic = kwargs.get('statistic','matrix')
method    = kwargs.get('method','abs_larger')

if   statistic == 'p_vals': stat_mat = G.p_vals[method]
elif statistic == 'q_vals': stat_mat = G.q_vals[method]
elif statistic == 'matrix': stat_mat = G.matrix

n = len(components)
## make component interactions
mat = np.zeros((n,n))
for pair in combinations(range(n),2):
c1, c2 = pair
nodes1 = components[c1]
nodes2 = components[c2]
vals   = np.array( map( lambda p: stat_mat[p[0]][p[1]], product(nodes1,nodes2) ) )
if linkage is 'avg': c_inter = vals.mean()
elif linkage is 'max':
abs_v   = np.abs( vals )
c_inter = vals[abs_v.argmax()]
mat[c1,c2] = c_inter
mat[c2,c1] = mat[c1,c2]
## get within component interactions
for c,nodes in enumerate(components):
vals   = np.array( map( lambda p: stat_mat[p[0]][p[1]], combinations(nodes,2) ) )
if linkage is 'avg': c_inter = vals.mean()
elif linkage is 'max':
abs_v   = np.abs( vals )
c_inter = vals[abs_v.argmax()]
mat[c,c] = c_inter
mat_md = MatrixDictionary()
labels = map(lambda i: 'comp_' + str(i), range(n))
mat_md.from_matrix(mat, labels, labels)
comp_net = OTUnetwork(node_ids = mat_md.row_labels(), matrix = mat_md, algo = G.algo)
return comp_net

def compute_pvals(G, method='local_p', approx_norm=True, two_tailed=True):
'''
Compute p-values using specified method.
method [str]:
local_p     : calculate p-value from raw and column, assuming values are normally distributed.
The test statistic -[log(p_row) - log(p_col)] has an Erlang distribution with shape (k) = 2, and scale (theta) = 1.
global_p    : calculate from all matrix, assuming values are normally distributed
approx_norm [bool]: approximate distribution of similarity matrix elements as normal.
'''
matrix = G.matrix.matrix
ids    = G.matrix.ids
n      = matrix.shape[0] # num ts in network

p_vals_mat = np.zeros((n,n))
# calculate mean and std of all rows and cols
if method == 'local_p':
row_mean = np.zeros(n)
row_std  = np.zeros(n)
col_mean = np.zeros(n)
col_std  = np.zeros(n)
for i in range(n):
row = np.delete(matrix[i,:], i)
col = np.delete(matrix[:,i], i)
row_mean[i] = row.mean()
row_std[i]  = row.std()
col_mean[i] = col.mean()
col_std[i]  = col.std()

# calculate empirical distribution of all matrix, excluding the diagonal values
if method == 'global_p':
vals = G.matrix.flatten()
if approx_norm:
global_mean = vals.mean()
global_std  = vals.std()
else: kde_pdf     = stats.gaussian_kde(vals)

## go over all pairs
for i in range(n):
for j in range(0,n):
if i ==3 and j==37:
pass
if i == j:
p_vals_mat[i,j]  = 1
continue

if method == 'local_p':
if approx_norm:
row = np.delete(matrix[i,:], np.where(matrix[i,:] == 1.5) )
col = np.delete(matrix[:,j], np.where(matrix[:,j] == 1.5) )
if len(row) < 2 or len(col) < 2:
p_vals_mat[i,j] = 1.5
continue
p_row = stats.norm(row_mean[i], row_std[i]).sf(matrix[i,j])
p_col = stats.norm(col_mean[j], col_std[j]).sf(matrix[i,j])
else:
kde_pdf_row = stats.gaussian_kde(row)
kde_pdf_col = stats.gaussian_kde(col)
p_row       = 1 - kde_pdf_row.integrate_box_1d(-1, matrix[i,j])
p_col       = 1 - kde_pdf_col.integrate_box_1d(-1, matrix[i,j])
if matrix[i,j] < 0:  # check for extreme negative values (for dissimilarity matrices with negative values)
p_row = 1 - p_row
p_col = 1 - p_col
if not min(p_row,p_col): statistic = - np.log(p_row) - np.log(p_col)
else:                    statistic = - np.log(p_row) - np.log(p_col)
p_vals_mat[i,j] = stats.gamma(2).sf(statistic)

elif method == 'global_p':
if approx_norm:
p = stats.norm(global_mean, global_std).sf(matrix[i,j])
if matrix[i,j] < 0: p = 1 - p # check for extreme negative values (for dissimilarity matrices with negative values)
if two_tailed:      p = 2*p
else:
if matrix[i,j] < 0:  p = kde_pdf.integrate_box_1d(-100, matrix[i,j])
else:                p = kde_pdf.integrate_box_1d(matrix[i,j], 100)
p_vals_mat[i,j] = p
p_vals_mat[p_vals_mat>1] = 1.0
p_vals = PairMatrix(p_vals_mat, ids)
G.p_vals[method] = p_vals
return p_vals

def compute_qvals(G, method='direct'):
'''
Compute q-values corresponding to p-values obtained using specified 'method'.
'''
from pysurvey.util.R_utilities import R_qvalues
p_vals = G.p_vals[method]
n      = p_vals.matrix.shape[0]

p_vals_flat = p_vals.flatten()
q_vals_flat = R_qvalues(p_vals_flat)
q_vals_mat  = p_vals.unflatten(q_vals_flat)
q_vals      = DF(q_vals_mat, p_vals.ids, p_vals.ids)
G.q_vals[method] = q_vals
return q_vals

'''
Add membership attribute to each node.
'''
for node in net.nodes_iter(): net.node[node]['component'] = membership[node]

'''
Add node labels with shortened lineages
'''
for node in net.nodes_iter():
net.node[node]['label'] = node.split('_')[1] + '_' + lineage_short(net.node[node]['lineage'])

def overlap(net1,net2):
'''
Return the overlap between two networks.
'''
edges1 = net1.edges()
edges2 = net2.edges()
shared = filter(lambda e: e in edges2, edges1) # all shared edges
agree  = filter(lambda e: net1[e[0]][e[1]]['sign'] == net2[e[0]][e[1]]['sign'], shared) # same sign edges
print filter(lambda e: e not in edges2, edges1) # all shared edges
n1       = len(edges1)
n2       = len(edges2)
n_shared = len(shared)
n_agree  = len(agree)
return n1,n2,n_shared,n_agree

def confusion_mat(net1,net2, sign=True, matched=True, **kwargs):
'''
Return the confusion matrix between two networks.
If matched, networks must have the same set of nodes.
Otherwise missing nodes are added with no connections.
Categories are +,0,- , rows are categories in net1, cols are in net2.
'''
from DataStructures.MyDataMatrix import MyDataMatrix as DM
from itertools import combinations
nodes1  = net1.nodes()
nodes2  = net2.nodes()
s1 = set(nodes1)
s2 = set(nodes2)
nodes_all = s1.union(s2)
if not matched:
else:
if s1 != s2:
raise ValueError('Networks must have exactly the same set of nodes')

edges1 = net1.edges()
edges2 = net2.edges()
if sign:
categories = ['P+','N','P-']
else:
categories = ['P','N']
k = len(categories)
confuse = DM(np.zeros( (k,k) ), index=categories, columns=categories)
for pair in combinations(nodes_all,2):
n1,n2 = pair
if net1.has_edge(*pair):
if sign:
if net1[n1][n2]['sign'] > 0: row = 'P+'
else:                        row = 'P-'
else:
row = 'P'
else:                            row = 'N'
if net2.has_edge(*pair):
if sign:
if net2[n1][n2]['sign'] > 0: col = 'P+'
else:                        col = 'P-'
else:
col = 'P'
else:                            col = 'N'
confuse[col][row] += 1

#    score_mat = kwargs.get('score_mat', np.array([ [0.0,1,1],[1,0,1],[1,1,0] ]))
#    score     = np.sum( score_mat * confuse.values )
return confuse

def distance(net1,net2):
'''
Return the distance between two matched nets ( i.e., networks that have the same set of nodes).
Distance is the average euclidean distance between the edge weights.
'''
nodes1 = net1.nodes()
nodes2 = net2.nodes()
if set(nodes1) != set(nodes2): raise ValueError('Networks must have exactly the same set of nodes')
adj1 = nx.to_numpy_matrix(net1, nodes1, dtype = float) # adjacency matrix of network 1
adj2 = nx.to_numpy_matrix(net2, nodes1, dtype = float) # adjacency matrix of network 2 with same order as of net1
k    = len(nodes1)
dist /= k*(k-1)
return dist

def add_cluster_edges(G, num_clusters, dist = 'e' ):
'''
Add edges between all nodes belonging to the same cluster.
'''
from itertools import combinations
abunds = G.abunds
for clust in range(num_clusters):
clust_abunds = abunds.filter_by_cluster([clust], dist)
ids          = clust_abunds.keys()
for pair in combinations(ids,2): G.add_edge(pair[0], pair[1], weight= 1 )

def lineage_short(lineage, style='HMP', depth_max=1):
'''
Return a str of the most resolved taxonomic unit of lineage.
'''
fields = lineage.split(';')
lin    = ''
depth  = 0
if style is 'HMP':
for tax in fields[-2::-1]:
try: float(tax)
except:
if not tax.startswith('unclassified'):
#                    tax.split('(')[0]
lin    = tax + ';' + lin
depth += 1
if depth == depth_max:
return lin
elif style is 'Knight':
assigned = filter(lambda field: field[-1] != '_' ,fields)
return assigned[-1][3:]
return lin #if didn't reach maximal depth

def plot_network(net, **kwargs):
'''
'''
if 'ax' in kwargs:
ax = kwargs['ax']
else:
plt.figure(facecolor='w', edgecolor='k')
ax = plt.gca()

G = deepcopy(net)

## input params
show       = kwargs.get('show',True)
lable_flag = kwargs.get('lable_flag',True)
show_lin   = kwargs.get('show_lin',False)
node_size  = kwargs.get('node_size',100)
edge_width = kwargs.get('edge_width', 2)
node_shape = kwargs.get('node_shape','o')
color_by   = kwargs.get('color_by','given')
node_cmap  = kwargs.get('node_cmap','Paired')
node_vmax  = kwargs.get('node_vmax', 1.)
node_vmin  = kwargs.get('node_vmin', 0.)
node_alpha = kwargs.get('node_alpha', .9)
file       = kwargs.get('file',None)
remove_disconnected = kwargs.get('remove_disconnected',False)

## remove disconnected nodes
if remove_disconnected: G = remove_disc_nodes(G)

## set node sizes
if 'node_sizes' in kwargs:
d         = kwargs['node_sizes']
node_size = [d[node] for node in G.nodes()]
if kwargs.get('scale_sizes', True):
max_node_size = kwargs.get('max_node_size', 500)
node_size = max_node_size/(1-np.log2(node_size))

#    ## create net with only positive weights
#    G_pos = deepcopy(G)
#    for (u,v,d) in G_pos.edges(data=True):
#        if d['weight'] < 0: d['weight'] = -d['weight']

## node layout
if 'pos' in kwargs: # user inputed node position
pos = kwargs['pos']
else:
layout = kwargs.get('layout','graphviz') # type of graph layout to use. ('graphviz' | 'spring')
prog   = kwargs.get('prog','fdp')          # layout program to be used if layout is graphviz. ('neato' | 'fdp' |...)
iter   = kwargs.get('iter',30)             # number of iteations to use if layout is spring
if layout is 'spring':
pos  = kwargs.get('pos',nx.spring_layout(G.to_undirected(),iterations = iter))
elif layout is 'graphviz':
#            pos = nx.pygraphviz_layout(G, prog = prog, args = '-Gmaxiter=5') # pass additional parameters to graphviz. G indicates that this is a Graph attributes.
pos = nx.pygraphviz_layout(G, prog = prog)
elif layout is 'circular':
pos = nx.drawing.layout.circular_layout(G)
## plot nodes
if color_by == 'component':
C = nx.connected_component_subgraphs(G.to_undirected()) # color nodes the same in each connected subgraph
for i,g in enumerate(C):
c   = [np.random.random()]*nx.number_of_nodes(g) # random color...
nodes = nx.draw_networkx_nodes(g,
pos,
node_size   = node_size,
node_shape  = node_shape,
node_color  = c,
cmap        = node_cmap,
vmin        = node_vmin,
vmax        = node_vmax,
with_labels = False,
alpha       = node_alpha,
ax          = ax,
linewidths  = 0
)
plot_edges(g, pos, ax, edge_width)

elif color_by in set(['degree','self_edge','given','community', 'phyla']):
if color_by is 'degree':
c = [float(G.degree(v)) for v in G] # color by degree dist.
elif color_by is 'self_edge':
c = [G.matrix[node][node] for node in G.nodes()]
elif color_by is 'given':
c = kwargs.get('node_color','LightSlateGray')
elif color_by is 'community':
m = kwargs['membership']
m_max = np.max(m.values())
c = []
for id in G.nodes():
if id in m: c.append(float(m[id]))
else:       c.append(float(m_max+1))
elif color_by is 'phyla':
c_dict = kwargs['node_color']
c = [c_dict[node] for node in G.nodes() ]
nodes = nx.draw_networkx_nodes(G,
pos,
node_size   = node_size,
node_shape  = node_shape,
node_color  = c,
cmap        = node_cmap,
vmin        = node_vmin,
vmax        = node_vmax,
with_labels = False,
alpha       = node_alpha,
ax          = ax,
linewidths  = 0
)
plot_edges(G, pos, ax, edge_width)

pos_vals = pos.values()
x_pos =  map(lambda p: p[0], pos_vals)
y_pos =  map(lambda p: p[1], pos_vals)
xmax  = np.max(x_pos)
xmin  = np.min(x_pos)
ymax  = np.max(y_pos)
ymin  = np.min(y_pos)
xdiff = (xmax-xmin)
ydiff = (ymax-ymin)
if xdiff == 0: xdiff = 1
if lable_flag:
xfactor  = kwargs.get('xfactor',.03)
yfactor  = kwargs.get('yfactor',.03)
xoffset  = kwargs.get('xoffset', xdiff*xfactor)
yoffset  = kwargs.get('yoffset', ydiff*yfactor)
plot_labels(G, pos, xoffset, yoffset, ax, show_lin = show_lin)

## set figure limits
plt.xlim(xmin = xmin - xdiff*.1, xmax = xmax + xdiff*.15 )
plt.ylim(ymin = ymin - ydiff*.075, ymax = ymax + ydiff*.075 )
#    plt.xlim(xmin = xmin - xdiff*.0, xmax = xmax + xdiff*.0 )
#    plt.ylim(ymin = ymin - ydiff*.00, ymax = ymax + ydiff*.00 )
#    plt.axes().set_aspect('equal', 'datalim')

plt.axis('off')  # turn off figure axis
set_axis = kwargs.get('set_axis','tight')
plt.axis(set_axis)
#    plt.sci(nodes)
#    plt.colorbar()
if file: plt.savefig(file, bbox_inches = 'tight')
if show: plt.show()

def plot_edges(G, pos, ax, w):
## draw edges
# get edges with positive and negative relations
#    e_pos=[(u,v) for (u,v,d) in G.edges(data=True) if 1-d['weight'] > 0]
#    e_neg=[(u,v) for (u,v,d) in G.edges(data=True) if 1-d['weight'] <=0]
e_pos=[(u,v) for (u,v,d) in G.edges(data=True) if d['weight'] > 0]
e_neg=[(u,v) for (u,v,d) in G.edges(data=True) if d['weight'] <=0]
e_pos_colors = map(lambda e:1-G[e[0]][e[1]]['weight'] ,e_pos)
e_neg_colors = map(lambda e:1-G[e[0]][e[1]]['weight'] ,e_neg)
e_pos_colors = 'g'
e_neg_colors = 'r'

#    nx.draw_networkx_edges(G, pos, edgelist=e_pos, alpha = .7, width=2,edge_color = e_pos_colors, style='solid')
#    nx.draw_networkx_edges(G, pos, edgelist=e_neg, alpha = .7, width=2,edge_color = e_neg_colors, style='solid')
cmap_pos = plt.cm.Greens
cmap_neg = plt.cm.Reds
edge_vmin_pos = 0
edge_vmax_pos = 1
edge_vmin_neg = edge_vmin_pos
edge_vmax_neg = edge_vmax_pos
a = 0.8
style_def='solid'
for e in e_pos:
#        c = 1-G[e[0]][e[1]]['weight']
c = G[e[0]][e[1]]['weight']
style = G[e[0]][e[1]].get('style', style_def)
nx.draw_networkx_edges(G, pos, edgelist=[e], alpha = a, width=w, edge_color = [c], style=style,
edge_cmap = cmap_pos, edge_vmin= edge_vmin_pos, edge_vmax = edge_vmax_pos, ax = ax)
for e in e_neg:
#        c = G[e[0]][e[1]]['weight']-1
c = -G[e[0]][e[1]]['weight']
style = G[e[0]][e[1]].get('style', style_def)
nx.draw_networkx_edges(G, pos, edgelist= [e], alpha = a, width=w,edge_color = [c], style=style,
edge_cmap = cmap_neg, edge_vmin= edge_vmin_neg, edge_vmax = edge_vmax_neg, ax = ax)

def plot_labels(G, pos, xoffset, yoffset, ax, show_lin = True):
## draw labels
labels    = {}
label_pos = {}
for node, p in pos.items():
if show_lin: labels[node] = node.split('_')[-1] + '_' + lineage_short(G.node[node]['lineage'])
else:
try:    labels[node] = node.strip('\'').strip('\"')
except: labels[node] = node
label_pos[node] = np.array([p[0] + xoffset, p[1] + yoffset])
nx.draw_networkx_labels(G, label_pos, labels = labels, font_size=8, font_weight='normal', ax = ax)

def plot_degree_dist(G, logx = False, logy = True):
plt.figure()
degree_sequence=sorted(nx.degree(G).values(),reverse=True) # degree sequence
#print "Degree sequence", degree_sequence
dmax=max(degree_sequence)

if logx and logy: plt.loglog(degree_sequence,'b-',marker='o')
elif logy:        plt.semilogy(degree_sequence,'b-',marker='o')
elif logx:        plt.semilogx(degree_sequence,'b-',marker='o')
else:             plt.plot(degree_sequence,'b-',marker='o')
plt.title("Degree rank plot")
plt.ylabel("degree")
plt.xlabel("rank")

def interaction_vs_phylo_dist(G,num_bins = 10, n_pseudo=0, norm = False, plot = True, log = False, ms = 10, alpha = 0.7, file = None, show = False, filter = None):
'''
Get the number/probability of interaction as a function of phylogenetic distance.
'''
stat_inter, phylo_dist_inter = stat_vs_phlyo_dist(G,plot = False, ms = ms, alpha = alpha, file = None, show = False, filter = filter)
stat_all, phylo_dist_all     = stat_vs_phlyo_dist(G,plot = False, ms = ms, alpha = alpha, file = None, show = False, filter = None)

## get interaction counts per distance
if log:
phylo_dist_inter = np.log10(phylo_dist_inter)
phylo_dist_all   = np.log10(phylo_dist_all)
n_all, bins   = np.histogram(phylo_dist_all, bins = num_bins)
n_inter, bins = np.histogram(phylo_dist_inter, bins = num_bins)
bin_centers   = bins[:-1] + np.diff(bins)/2

## estimate probability of interaction from counts
n_all_float = np.array(n_all, float)
n_all_pseudo   = n_all_float + n_pseudo
n_inter_pseudo = n_inter + n_pseudo
p_inter        = n_inter_pseudo/n_all_pseudo

## plot
if plot:
fig = plt.figure()
if norm:
h      = ax.plot(bin_centers,p_inter,'--o', lw = 2, ms = ms, alpha = alpha)
ylabel = 'Probability of relation'
else:
h1     = ax.plot(bin_centers,n_inter,'--o', lw = 2, ms = ms, alpha = alpha)
h2     = ax.plot(bin_centers,n_all,  '--o', lw = 2, ms = ms, alpha = alpha)
ylabel = 'Number of Pairs'
plt.legend([h1,h2],['Interacting', 'All'])
if log: plt.xlabel('log10 (Genetic dist)')
else:   plt.xlabel('Genetic dist')
plt.ylabel(ylabel)

if file: plt.savefig(file)
if show: plt.show()
return p_inter, bin_centers

def conditional_stat_dist(G, statistic = 'matrix', num_bins = 10, plot = True, log = False, file = None, show = False, sig_filter = None):
stats_all, dist = stat_vs_phlyo_dist(G,statistic = statistic, plot = False, file = None, show = False, filter = sig_filter)

## get interaction counts per distance
if log: dist = np.log10(dist)
n, bins      = np.histogram(dist, bins = num_bins)
bin_centers  = bins[:-1] + np.diff(bins)/2

## get stat as function of dist
stats_cond = []
num_points = len(dist)
for i in range(len(bins)-1):
min = bins[i]
max = bins[i+1]
inds = filter(lambda z: min<=dist[z] and dist[z]<max, range(num_points))
stats_cond.append(stats_all[inds])

kde_2d(dist,stats_all)

#    ## estimate stat pdf as a function of dist
#    stat_pdfs  = []
#    point_list = []
#    kde = True
#    if kde:
#        points = np.linspace(-1,1,100)
#        for s in stats_cond:
#            kde_pdf = stats.gaussian_kde(s)
#            pdf     = kde_pdf.evaluate(points)
##            adjust = 1
##            points, pdf = kder(s, adjust)
#            stat_pdfs.append(pdf)
#            point_list.append(points)
#    else:
#        for s in stats_cond:
#            pdf, bins    = np.histogram(s, bins = 10, normed = True, range=(-1,1))
#            stat_pdfs.append(pdf)
#        points = bins[:-1] + np.diff(bins)/2
#
#    ## plot
#    if plot:
#        fig = plt.figure()
#        ax  = fig.add_subplot(111)
#        h = []
#        for pdf,points in zip(stat_pdfs, point_list):
#            h_temp = ax.plot(points,pdf,'-', lw = 2)
#            h.append(h_temp)
#        plt.xlabel('Tau')
#        plt.ylabel('pdf')
#        plt.legend(h,map(lambda x: '%.2f' %x, bin_centers))
#
#        if file: plt.savefig(file)
#        if show: plt.show()
#    return bin_centers, stat_pdfs

def stat_vs_phlyo_dist(G, dist_mat, plot = True, log = False, statistic = 'p_vals', method = 'direct', ms = 10, alpha = 0.7, file = None, show = False, filter = None):
'''
filter['stat'] = [str] statistic to filter by ('matrix' | 'matrix_abs' | 'p_vals' | 'q_vals')
filter['min'] = [float] min stat
filter['max'] = [float] max stat
'''
from itertools import combinations
import scipy.stats as stats
if   statistic == 'p_vals': stat_mat = G.p_vals[method]
elif statistic == 'q_vals': stat_mat = G.q_vals[method]
elif statistic == 'matrix': stat_mat = G.matrix
## get list of all id pairs
ids      = stat_mat.ids   # otu ids
pair_ids = []
for pair in combinations(ids,2):
pair_ids.append(pair)

phylo_dist = dist_mat.vals_by_keys(pair_ids)   ## get phylogenetic distances
stat       = stat_mat.vals_by_keys(pair_ids)   ## get corresponding stats

## Remove stat values outside range defined by stat_min & stat_max
if filter is not None:
filter_method = filter['method']
if   filter['stat'] == 'p_vals': filter_mat = G.p_vals[filter_method]
elif filter['stat'] == 'q_vals': filter_mat = G.q_vals[filter_method]
elif filter['stat'] in set(['matrix', 'matrix_abs' ]): filter_mat = G.matrix

filter_stat = filter_mat.vals_by_pair_ids(pair_ids)
if filter['stat'] == 'matrix_abs': filter_stat = np.abs(filter_stat)

del_inds = np.array([])
if 'min' in filter: del_inds = np.r_[del_inds, np.where(filter_stat < filter['min'])[0]]
if 'max' in filter: del_inds = np.r_[del_inds, np.where(filter_stat > filter['max'])[0]]
stat       = np.delete(stat,       del_inds)
phylo_dist = np.delete(phylo_dist, del_inds)

print len(stat)
if len(stat) < 3: return  ## exist function if less than 3 pairs pass filter

if log: stat = np.log10(stat)
phylo_dist   = np.log10(phylo_dist)

## plot
if plot:
#        ### scatter plot of stat vs dist
#        fig = plt.figure()
#        ax  = fig.add_subplot(111)
#        h   = ax.plot(np.log10(phylo_dist),stat,'o', ms = ms, alpha = alpha)
#        ## add box with kendall tau and statistical support
#        tau, pval = stats.kendalltau(phylo_dist, stat)
#        stats_str = 'Tau = %.2f \n p_val = %.1e' %(tau,pval)
#        if stats_str:
#            xmin, xmax = plt.xlim()
#            ymin, ymax = plt.ylim()
#            ax.annotate(stats_str, xy=(xmin + 0.05*xmax, ymax - 0.1*abs(ymax) ),  xycoords='data',
#                        xytext=(0,0), textcoords='offset points',
#                        size=10,
#                        bbox=dict(boxstyle="round", fc=(1.0, 0.7, 0.7), ec=(1., .5, .5))
#                        )
kde_2d(phylo_dist,stat)
plt.xlabel('log10 (Genetic dist)')
if statistic == 'matrix': ylabel = G.algo
else:                     ylabel = statistic + ' ' + G.algo
if log: ylabel = 'log10 (' + ylabel +')'
plt.ylabel(ylabel)
if file: plt.savefig(file)
if show: plt.show()
return stat, phylo_dist

def stat_vs_dist(G, dist_mat, **kwargs):
from matplotlib import rc
rc('text', usetex=True)

from itertools import combinations
from utilities.smoothing import exp_smooth

statistic = kwargs.get('statistic', 'matrix')
method    = kwargs.get('method', 'direct')
if   statistic == 'p_vals': stat_mat = G.p_vals[method]
elif statistic == 'q_vals': stat_mat = G.q_vals[method]
elif statistic == 'matrix': stat_mat = G.matrix

## get dist and stat of all pairs
ids        = stat_mat.row_labels()   # otu ids
n          = len(ids)
n_pairs    = n*(n-1)/2.0
phylo_dist = np.zeros(n_pairs)
stat       = np.zeros(n_pairs)
for i,pair in enumerate(combinations(ids,2)):
p1,p2         = pair
phylo_dist[i] = dist_mat[p1][p2]
stat[i]       = stat_mat[p1][p2]

## sort
x    = np.array(sorted(phylo_dist))
inds = phylo_dist.argsort()
y    = stat[inds]

## Get smooth curves
a = kwargs.get('a',100)
xnew,ynew = exp_smooth(x,y, a= a, pad = 'none')
xabs,yabs = exp_smooth(x,abs(y), a= a, pad = 'none')

## plot
plot = kwargs.get('plot', True)
if plot:
lw       = kwargs.get('lw', 3)
fs       = kwargs.get('fs', 18)
fig      = plt.figure()
h        = plt.plot(np.log10(x),y,'o')
h_smooth = plt.plot(np.log10(xnew),ynew,'-', lw=lw)
h_abs    = plt.plot(np.log10(xabs),yabs,'--', lw=lw)
plt.legend([h,h_smooth,h_abs],['data','avg','abs avg'])
plt.xlabel('log10 (Genetic dist)', fontsize = fs)
if statistic == 'matrix': ylabel = G.algo
else:                     ylabel = statistic + ' ' + G.algo
ylabel = 'Correlation'
plt.ylabel(ylabel, fontsize = fs)
plt.ylim([-1,1])
plt.xlim([-2.0,.5])
plt.grid()

file = kwargs.get('file', False)
if file: plt.savefig(file)

show = kwargs.get('show', True)
if show: plt.show()
return x,y, xnew,ynew, xabs,yabs

def plot_pairs(G, lag = 0, method = 'direct', scatter = False, base_file = None, norm_tot = True, norm = True, show = False, max_pairs = 20):
'''
'''
file = None
pairs = sorted(G.edges(), key = lambda pair: G.p_vals[method].matrix[G.ids.index(pair[1]),G.ids.index(pair[0])] )   #G.p_vals[pair[0]][pair[1]]['weight'] )
for i,pair in enumerate(pairs):
if i == max_pairs: break
if base_file:
if scatter: file = base_file + '_pair_scatter_' +str(i) + '.pdf'
else:       file = base_file + '_pair_' +str(i) + '.pdf'
if lag: driver_flags = [1,0]
else:   driver_flags = [0,0]

row = G.ids.index(pair[1])
col = G.ids.index(pair[0])
sim = G.matrix.matrix[row,col]
algo      = G.algo
stats_str = algo + '_ %.2f' %sim
if method in G.p_vals:
p_val = G.p_vals[method].matrix[row,col]
stats_str += '\n p_val = %.1e' %p_val
if method in G.q_vals:
q_val = G.q_vals[method].matrix[row,col]
stats_str += '\n q_val = %.1e' %q_val

if scatter: G.abunds.plot_vs_other(otu_ids = pair, driver_flags = driver_flags, lag = lag, file = file, norm_tot = norm_tot, norm = norm, show = False, stats_str = stats_str)
else:       G.abunds.plot(otu_ids = pair, driver_flags = driver_flags, lag = lag, file = file, norm = norm, show = False, stats_str = stats_str)
if show and len(pairs) : plt.show()

def write_components(net, components, file):
import csv
f = open(file,'wb')
header = ['Component', 'OTU id', 'lineage']
csvWriter = csv.writer(f)
writer    = csvWriter.writerow
for i,c in enumerate(components):
writer([''])
for otu in c:
line = ['Comp_' +str(i), otu, net.node[otu]['lineage'] ]
writer(line)
#    ## write otus in components
#    n_comp = map(lambda comp: len(comp),components)
#    n_max  = (np.array(n_comp)).max()
#    for i in range(n_max):
#        line = []
#        for comp in components:
#            if len(comp) > i: line.append(comp[i])
#            else:             line.append('')
#        writer(line)
f.close()

def write_edges(G, file, csv_flag = True, by_component = True, dist = None, write_lin = True):
'''
Write all edges to file.
Inputs:
G        = [OTUnet] networks whose edges are to be written.
file     = [str] name of output file.
csv_flag = [bool] write to csv format or not.
dist     = [MatrixDictionary] phylogenetic distances between all nodes. distances will be written if given.
'''
import csv
f = open(file,'wb')
if write_lin: header = '\t'.join(['Weight','From_OTU', 'To_OTU', 'From_lineage', 'To_lineage', 'From_lineage_short', 'To_lineage_short']) +'\n'
else:         header = '\t'.join(['Weight','From_OTU', 'To_OTU']) +'\n'
if dist is not None: header.append('\t'+'phylo_dist')
if by_component: header = 'Component' +'\t' + header
if csv_flag:
csvWriter = csv.writer(f)
writer = csvWriter.writerow
else:
writer = f.write
if by_component:
C = nx.connected_component_subgraphs(G.to_undirected()) # color nodes the same in each connected subgraph
for i,g in enumerate(C):
if csv_flag: writer([''])
else:        writer('\n')
edges = g.edges()
sorted_edges = sorted(edges, key = lambda e: abs(G.edge[e[0]][e[1]]['weight']), reverse=True )
for e in sorted_edges:
node_from = e[0]
node_to   = e[1]
if node_from == node_to: continue
weight = G.edge[node_from][node_to]['weight']
lin1 = G.node[node_from]['lineage']
lin2 = G.node[node_to]['lineage']
lin1_short = lineage_short(lin1, depth_max = 2)
lin2_short = lineage_short(lin2, depth_max = 2)
line   = '\t'.join([str(i+1),'%.2f' %weight, str(node_from), str(node_to), lin1, lin2 ]) + '\n'
if dist is not None:
line = line[:-2] + '\t' + '%.1e' %dist[node_from][node_to] + '\n'
if csv_flag: writer(line.split('\t'))
else:        writer(line)
else:
edges = G.edges()
sorted_edges = sorted(edges, key = lambda e: abs(G.edge[e[0]][e[1]]['weight']), reverse=True )
for e in sorted_edges:
node_from = e[0]
node_to   = e[1]
if node_from == node_to: continue
weight = G.edge[node_from][node_to]['weight']
if dist is not None:
line = '\t'.join(['%.1e' %weight, str(node_from), str(node_to), G.node[node_from]['OTU_id'], G.node[node_to]['OTU_id'], G.node[node_from]['lineage'], G.node[node_to]['lineage'],'%.1e' %dist[node_from][node_to] ]) + '\n'
else:
t = G.node[node_from]
if write_lin:
lin1 = G.node[node_from]['lineage']
lin2 = G.node[node_to]['lineage']
lin1_short = lineage_short(lin1, depth_max = 2)
lin2_short = lineage_short(lin2, depth_max = 2)
line = '\t'.join(['%.2f' %weight, str(node_from), str(node_to), lin1, lin2, lin1_short, lin2_short ]) + '\n'
else:         line = '\t'.join(['%.2f' %weight, str(node_from), str(node_to) ]) + '\n'
if csv_flag: writer(line.split('\t'))
else:        writer(line)
f.close()

def write_pajek(G,file):
'''
Write network in pajek format
'''
G = remove_disc_nodes(G)
nx.write_pajek(G, file) ## write to pajek
## add '/r' to end of each line for windows compatibility
f = open(file,'r')
f.close()
new_lines = map(lambda line:line + '\r' , lines)
f = open(file,'w')
f.writelines(new_lines)
f.close()

def kde_2d(m1,m2):
from pylab import plot, figure, imshow, xlabel, ylabel, cm, show
from scipy import stats, mgrid, c_, reshape, random, rot90

xmin = m1.min()
xmax = m1.max()
ymin = m2.min()
ymax = m2.max()

xmin = -2.5
xmax = .5
ymin = -0.80
ymax = 1

# Perform a kernel density estimator on the results
X, Y = mgrid[xmin:xmax:100j, ymin:ymax:100j]
positions = c_[X.ravel(), Y.ravel()]
values = c_[m1, m2]
kernel = stats.kde.gaussian_kde(values.T)
Z = reshape(kernel(positions.T).T, X.T.shape)

figure()
imshow(     rot90(Z),
cmap=cm.gist_earth_r,
extent=[xmin, xmax, ymin, ymax])
plot(m1, m2, 'o',  markersize=1.5)
#    plot(m1, m2, 'mx', markersize=.5)
plt.xlim(xmin,xmax)
plt.ylim(ymin,ymax)
plt.axes().set_aspect('auto', 'box')
return kernel
#
#
#def test_confusion():
#    from HMPStructures.HMPmatrix import HMP_matrix
#    k = 3
#    node_ids = [str(i) for i in xrange(k)]
#    mat1      = np.eye(k)
#    mat1[0,1] = -1
#    mat1[1,0] = -1
#    m1 = HMP_matrix()
#    m1.from_matrix(mat1, node_ids, node_ids)
#
#    mat2      = np.eye(k)
#    mat2[2,1] = 1
#    mat2[1,2] = 1
#    m2 = HMP_matrix()
#    m2.from_matrix(mat2, node_ids, node_ids)
#
#    net1 = OTUnetwork(node_ids = node_ids, matrix = m1)
#    net2 = OTUnetwork(node_ids = node_ids, matrix = m2)
#    add_edges(net1, cutoff = 0.5, statistic = 'matrix', method = 'abs_larger')
#    add_edges(net2, cutoff = 0.5, statistic = 'matrix', method = 'abs_larger')
#    print confusion_mat(net1,net2)
#
#    from DataStructures.MyDataMatrix import MyDataMatrix as DM
#    vals = np.array([[0,1,0],
#                     [1,0,1],
#                     [0,1,0]])
#    ids = ['aa','bb','cc']
#    matrix = DM(vals, index=ids, columns=ids)
#    net = OTUnetwork(matrix)
#    print net.edges()
#
#
#
if __name__ == '__main__':
from DataStructures.MyDataMatrix import MyDataMatrix as DM
path = '/Users/jonathanfriedman/Documents/Alm/enterotype/correlations/'
file = 'SparCC_with_pseudo_xiter_0_minpresent_3.pick'
cSparCC_f = DM.fromPickle(path+file)
file = 'Spearman_no_pseudo_minpresent_3.pick'
cSpear_f = DM.fromPickle(path+file)
netSparCC_f = OTUnetwork(cSparCC_f)
netSpear_f = OTUnetwork(cSpear_f)

taxons = ['Bacteroides', 'Ruminococcus', 'Prevotella']
taxon = taxons[0]
netSpear_ff = remove_disc_nodes(netSpear_f)
netSparCC_ff = remove_disc_nodes(netSparCC_f)
confuse = confusion_mat(netSparCC_ff, netSpear_ff, matched=False, sign=False)
print confuse
#    plot_network(netSpear_ff,show=False, remove_disconnected=True)
#    plot_network(netSparCC_ff, remove_disconnected=True)

```