#!/usr/bin/python

# The SDP example of section 8.6.

from cvxopt import base, blas, lapack, random, solvers
from cvxopt.base import matrix

def mcsdp(w):
    """
    Returns solution x, z to

        (primal)  minimize    sum(x)
                  subject to  w + diag(x) >= 0

        (dual)    maximize    -tr(w*z)
                  subject to  diag(z) = 1
                              z >= 0.

    w is an n x n 'd' matrix (NOTE(review): assumed symmetric; the
    LAPACK/BLAS calls below read only its lower triangle -- confirm that
    callers rely on that).

    Returns the pair (sol['x'], sol['zs'][0]) produced by solvers.conelp:
    the n x 1 primal variable x and the n x n dual matrix z.

    Raises ValueError if w is not square.
    """

    if w.size[0] != w.size[1]:
        raise ValueError("mcsdp: w must be a square matrix")

    n = w.size[0]

    def Fs(x, y, alpha=1.0, beta=0.0, trans='N'):
        """
        Matrix-free evaluation of the constraint map x -> -diag(x) and of
        its adjoint:

            trans == 'N':  y[0] := alpha*(-diag(x)) + beta*y[0]
                           (x is an n-vector, y[0] an n x n matrix)
            otherwise:     y := alpha*(-diag(x[0])) + beta*y
                           (x[0] is an n x n matrix whose diagonal is
                           extracted, y is an n-vector)
        """
        if trans == 'N':
            # x is a vector; y[0] is a matrix.  Stride n+1 over the n x n
            # matrix y[0] visits exactly its diagonal.
            blas.scal(beta, y[0])
            blas.axpy(x, y[0], alpha=-alpha, incy=n+1)

        else:
            # x[0] is a matrix; y is a vector.  Stride n+1 reads the
            # diagonal of x[0].
            blas.scal(beta, y)
            blas.axpy(x[0], y, alpha=-alpha, incx=n+1)

    def cngrnc(r, x, alpha=1.0):
        """
        Congruence transformation

            x := alpha * r'*x*r.

        r and x are square n x n matrices; x is overwritten in place.
        blas.trmm and blas.syr2k are called with their default uplo
        ('L'), so only the lower triangle of x is read and written.
        """

        # Scale diagonal of x by 1/2, so that with L = tril(x) (halved
        # diagonal) we have L + L' = x for symmetric x.
        x[::n+1] *= 0.5

        # a := tril(x)*r
        a = +r
        blas.trmm(x, a, side='L')

        # x := alpha*(r'*a + a'*r) = alpha * r'*(L + L')*r = alpha * r'*x*r
        # (syr2k with trans='T' forms A'*B + B'*A with A = r, B = a).
        blas.syr2k(r, a, x, trans='T', alpha=alpha)

    def kktsolver(d, r):
        """
        Factor the reduced KKT system for the current scaling and return a
        function f(x, y, zl, zs) that solves it in place.

        Only r is used: r[0] is an n x n matrix.  d is not referenced
        (NOTE(review): presumably the scaling of the linear cone, which this
        problem does not have -- confirm against the solver's spec).
        """

        # t = r*r' as a nonsymmetric (fully stored) matrix.
        t = matrix(0.0, (n, n))
        blas.gemm(r[0], r[0], t, transB='T')

        # Cholesky factorization of tsq = t.*t.  If r[0] is nonsingular,
        # t is positive definite, and so is its elementwise square
        # (Schur product theorem), which is what potrf requires.
        tsq = t**2
        lapack.potrf(tsq)

        def f(x, y, zl, zs):
            """
            Solve

                          -diag(zs)               = bx
                -diag(x) - inv(r*r')*zs*inv(r*r') = bs

            On entry, x and zs contain bx and bs.
            On exit, they contain the solution, with zs scaled
            (inv(r)*zs*inv(r)' is returned instead of zs).

            We first solve

                ((r*r') .* (r*r')) * x = bx - diag(t*bs*t)

            and take zs  = -r' * (diag(x) + bs) * r.

            y and zl are not referenced (NOTE(review): presumably because
            the problem has no equality constraints and no linear
            inequalities -- confirm).
            """

            # tbst := t * zs * t = t * bs * t   (t = r*r' is symmetric)
            tbst = +zs[0]
            cngrnc(t, tbst)

            # x := x - diag(tbst) = bx - diag(r*r' * bs * r*r')
            x -= tbst[::n+1]

            # x := (t.*t)^{-1} * x = (t.*t)^{-1} * (bx - diag(t*bs*t))
            lapack.potrs(tsq, x)

            # zs := zs + diag(x) = bs + diag(x)
            zs[0][::n+1] += x

            # zs := -r' * zs * r = -r' * (diag(x) + bs) * r
            cngrnc(r[0], zs[0], alpha=-1.0)

        return f

    c = matrix(1.0, (n, 1))

    # Initial strictly feasible x:  x = 1.0 - min(lambda(w)), so that
    # w + diag(x0) = w + (1 - lambda_min)*I has all eigenvalues >= 1.
    # range='I', il=iu=1 asks syevx for the smallest eigenvalue only;
    # +w passes a copy because syevx overwrites its matrix argument.
    lmbda = matrix(0.0, (n, 1))
    lapack.syevx(+w, lmbda, range='I', il=1, iu=1)
    x0 = matrix(-lmbda[0]+1.0, (n, 1))
    s0 = +w
    s0[::n+1] += x0

    # Initial feasible z is identity: diag(z0) = 1 and z0 is positive
    # definite.
    z0 = matrix(0.0, (n, n))
    z0[::n+1] = 1.0

    sol = solvers.conelp(c, kktsolver, Gs=Fs, hs=[w],
        primalstart={'x': x0, 'ss': [s0]}, dualstart={'zs': [z0]})
    return sol['x'], sol['zs'][0]

if __name__ == '__main__':
    # Demo: solve a random 100 x 100 instance.  Guarded so that importing
    # this module does not trigger the (expensive) solve.
    # NOTE(review): random.normal(n, n) is not symmetric; presumably only
    # its lower triangle is used by the solver -- confirm.
    n = 100
    w = random.normal(n, n)
    x, z = mcsdp(w)
