from phylonetwork import all_trees as tg
from math import factorial

def prod(l):
    """Return the product of the items of ``l`` (1 when ``l`` is empty)."""
    result = 1
    for factor in l:
        result = result * factor
    return result

def kappa(tree):
    """Map every node of ``tree`` to the size of its cluster.

    The value stored for node ``u`` is ``len(tree.cluster(u))``; presumably
    the cluster is the set of leaves descending from ``u`` (TODO confirm
    against phylonetwork).
    """
    return dict((node, len(tree.cluster(node))) for node in tree.nodes())

def probability(tree):
    """Return the Yule-model probability of ``tree``.

    Computes 2**(n-1) / n! * prod over interior nodes u of 1/(kappa(u) - 1),
    where n is the number of leaves and kappa(u) the cluster size of u.
    """
    sizes = kappa(tree)
    leaf_count = len(tree.leaves())
    factors = [1.0/(sizes[node]-1) for node in tree.interior_nodes()]
    return 2.0**(leaf_count-1)/factorial(leaf_count)*prod(factors)

def sackin(tree):
    """Return the Sackin index: the sum of cluster sizes over interior nodes."""
    sizes = kappa(tree)
    total = 0
    for node in tree.interior_nodes():
        total += sizes[node]
    return total

def colless(tree):
    """Return the Colless index of a binary ``tree``.

    Sums, over interior nodes u, |k1 - k2|, where k1 and k2 are the cluster
    sizes (see ``kappa``) of u's two children.

    Raises:
        ValueError: if an interior node does not have exactly two children.
            The index is only defined for binary trees; previously extra
            children were silently ignored.
    """
    kappas = kappa(tree)
    temp = 0
    for u in tree.interior_nodes():
        # list() so indexing works whether successors() returns a list or an
        # iterator (networkx >= 2 returns an iterator).
        children = list(tree.successors(u))
        if len(children) != 2:
            raise ValueError(
                "colless index requires a binary tree; node %r has %d children"
                % (u, len(children)))
        k1 = kappas[children[0]]
        k2 = kappas[children[1]]
        temp += abs(k1 - k2)
    return temp
        
def cophenetic(tree):
    """Return the cophenetic index of ``tree``.

    Sums, over interior nodes u, ``tree.depth(u)`` times the number of
    unordered leaf pairs taken from two *distinct* children of u. If u's
    children have cluster sizes k_1..k_m (see ``kappa``), that count is
    sum_{i<j} k_i * k_j. For a binary node this is k1 * k2, exactly what
    the original binary-only code computed. The general form also handles
    multifurcating and single-child nodes instead of failing or ignoring
    children.
    """
    kappas = kappa(tree)
    total = 0
    for u in tree.interior_nodes():
        # Iterating (rather than indexing) works whether successors() returns
        # a list or an iterator.
        sizes = [kappas[v] for v in tree.successors(u)]
        size_sum = sum(sizes)
        # sum_{i<j} k_i k_j == ((sum k)^2 - sum k^2) / 2, exact in integers.
        pairs = (size_sum * size_sum - sum(k * k for k in sizes)) // 2
        total += tree.depth(u) * pairs
    return total

def expected(pr,data):
    """Return the expectation of ``data`` under the distribution ``pr``.

    ``pr[t]`` is the probability of outcome t and ``data[t]`` its value; the
    result is sum_t data[t] * pr[t].

    Raises:
        ValueError: if ``pr`` and ``data`` have different lengths (previously
            this raised IndexError or silently ignored trailing values).
    """
    if len(pr) != len(data):
        raise ValueError("pr and data must have the same length (%d != %d)"
                         % (len(pr), len(data)))
    return sum(x * p for p, x in zip(pr, data))

def variance(pr,data):
    """Return the variance of ``data`` under the distribution ``pr``.

    Two-pass centred form: sum_t pr[t] * (data[t] - E[data])**2. This assumes
    ``pr`` sums to 1. The length check is performed by ``expected``.

    Raises:
        ValueError: if ``pr`` and ``data`` have different lengths.
    """
    mean = expected(pr, data)
    return sum(((x - mean) ** 2) * p for p, x in zip(pr, data))

def covariance(pr,data1,data2):
    """Return the covariance of ``data1`` and ``data2`` under ``pr``.

    Uses the centred form sum_t pr[t] * (x_t - E[x]) * (y_t - E[y]). This
    matches the shape of ``variance``, so covariance(pr, d, d) agrees with
    variance(pr, d), and it avoids the cancellation of E[xy] - E[x]E[y] when
    the means are large. The two forms are equal when ``pr`` sums to 1,
    which is assumed here as in ``variance``.

    Raises:
        ValueError: if the three sequences do not all have the same length.
    """
    if len(data1) != len(data2):
        raise ValueError("data1 and data2 must have the same length (%d != %d)"
                         % (len(data1), len(data2)))
    mean1 = expected(pr, data1)
    mean2 = expected(pr, data2)
    return sum((x - mean1) * (y - mean2) * p
               for p, x, y in zip(pr, data1, data2))


def harmonic(n):
    """Return the harmonic number H_n = 1 + 1/2 + ... + 1/n (0 for n < 1)."""
    total = 0
    for k in range(1, n + 1):
        total += 1.0 / k
    return total

def harmonic2(n):
    """Return the second-order harmonic number 1 + 1/4 + ... + 1/n**2 (0 for n < 1)."""
    total = 0
    for k in range(1, n + 1):
        total += 1.0 / (k * k)
    return total

def expCol(n):
    """Return the closed-form expected Colless index for ``n`` leaves.

    Formula: (n mod 2) + n * (H_{floor(n/2)} - 1). Explicit floor division
    (``//``) gives the same result as the Python 2 integer ``/`` this was
    written with. On Python 3 the old ``/`` passed a float to ``harmonic``,
    and ``range`` then raised TypeError.
    """
    return (n % 2) + n * (harmonic(n // 2) - 1)

def expSac(n):
    """Return the closed-form expected Sackin index: 2n(H_n - 1)."""
    h = harmonic(n)
    return 2 * n * (h - 1)

def expCop(n):
    """Return the closed-form expected cophenetic index: n(n-1) - 2n(H_n - 1)."""
    h = harmonic(n)
    return n * (n - 1) - 2 * n * (h - 1)

def expCopBar(n):
    """Return the closed-form expected value of cophenetic + Sackin: n(n-1)."""
    return (n - 1) * n

def varCol(n):
    """Return the closed-form variance of the Colless index for ``n`` leaves.

    The formula uses floor(n/2), floor(n/4) and floor((n+2)/4). They are now
    spelled with explicit ``//``, which matches the Python 2 integer ``/`` the
    formula was written with. On Python 3 the old ``/`` produced floats, and
    ``harmonic``/``harmonic2`` then failed in ``range``.
    """
    nf = float(n)
    half = n // 2
    quarter = n // 4
    shifted_quarter = (n + 2) // 4
    # Terms are kept in the original order so the float result is unchanged.
    return ((5*nf**2 + 7*n)/2 + (6*nf + 1)*half - 4*half**2
            + 8*shifted_quarter**2
            - 8*(nf + 1)*shifted_quarter - 6*n*harmonic(n)
            + (2*half - n*(n - 3))*harmonic(half)
            - n**2 * harmonic2(half)
            + (n**2 + 3*n - 2*half) * harmonic(shifted_quarter)
            - 2*n*harmonic(quarter))

def varSac(n):
    """Return the closed-form Sackin variance: 7n^2 - 4n^2 H2_n - 2n H_n - n."""
    h1 = harmonic(n)
    h2 = harmonic2(n)
    return 7 * n**2 - 4 * n**2 * h2 - 2 * n * h1 - n

def varCop(n):
    """Return the closed-form cophenetic variance.

    (n^4 - 10n^3 + 131n^2 - 2n)/12 - 4n^2 H2_n - 6n H_n.
    """
    polynomial = (1.0 / 12) * (n**4 - 10 * n**3 + 131 * n**2 - 2 * n)
    return polynomial - 4 * n**2 * harmonic2(n) - 6 * n * harmonic(n)

def varCopBar(n):
    """Return the closed-form variance of cophenetic + Sackin: 2n(n-1)(n-2)(n-3)/24."""
    scaled = 2.0 * n
    for offset in (1, 2, 3):
        scaled = scaled * (n - offset)
    return scaled / 24

def covSacCop(n):
    """Return the closed-form Sackin/cophenetic covariance.

    4n(n H2_n + H_n) + n(n^2 - 51n + 2)/6.
    """
    harmonic_part = 4 * n * (n * harmonic2(n) + harmonic(n))
    cubic_part = (1.0 / 6) * n * (n**2 - 51 * n + 2)
    return harmonic_part + cubic_part

def covSacCopBar(n):
    """Return the closed-form covariance of Sackin and cophenetic + Sackin.

    2n H_n + n(n^2 - 9n - 4)/6.
    """
    harmonic_part = 2 * n * harmonic(n)
    cubic_part = (1.0 / 6) * n * (n**2 - 9 * n - 4)
    return harmonic_part + cubic_part

# For each n, enumerate trees on n taxa with phylonetwork's all_trees
# (binary=True; presumably every binary tree on those taxa -- TODO confirm).
# Each tree is weighted by probability(); the exact moments of the Sackin,
# cophenetic, cophenetic + Sackin ("cophebar") and Colless indices are
# printed next to the closed-form value in parentheses.
# Single-argument print(...) calls give identical output on Python 2 and 3.
for n in range(3, 10):
    taxa = [str(i + 1) for i in range(n)]
    tgen = tg(taxa=taxa, binary=True, nested_taxa=False)
    pr = []
    sackins = []
    cophenetics = []
    collesss = []
    cophebars = []
    for t in tgen:
        pr.append(probability(t))
        s = sackin(t)
        sackins.append(s)
        c = cophenetic(t)
        cophenetics.append(c)
        cophebars.append(c + s)
        collesss.append(colless(t))
    print(n)
    print("E(sackin) = %f (%f)" % (expected(pr, sackins), expSac(n)))
    print("E(cophenetic) = %f (%f)" % (expected(pr, cophenetics), expCop(n)))
    print("E(cophebar) = %f (%f)" % (expected(pr, cophebars), expCopBar(n)))
    print("E(colless) = %f (%f)" % (expected(pr, collesss), expCol(n)))
    print("V(sackin) = %f (%f)" % (variance(pr, sackins), varSac(n)))
    print("V(cophenetic) = %f (%f)" % (variance(pr, cophenetics), varCop(n)))
    print("V(cophebar) = %f (%f)" % (variance(pr, cophebars), varCopBar(n)))
    print("V(colless) = %f (%f)" % (variance(pr, collesss), varCol(n)))
    print("Cov(sackin,cophenetic) = %f (%f)"
          % (covariance(pr, sackins, cophenetics), covSacCop(n)))
    print("Cov(sackin,cophebar) = %f (%f)"
          % (covariance(pr, sackins, cophebars), covSacCopBar(n)))
    


