module gibbs_gm
  use linalgebra_f90
  use random
  implicit none

  integer, private :: nburn=1000

contains
  subroutine gibbs_sampler(ndiag,vn,atom_ref_flag,atom_fix_flag,am,nrest_per_atom,rest_per_atom,rest_per_atom_pos,  &
       rest_per_atom_dist,nburn_in)
    !
    !  Draw a sample from a Gauss-Markov random field with a Gibbs sampler.
    !  The precision matrix defining the field is assumed to be sparse and
    !  already conditioned (see simple_atom_conditioner), so that its 3x3
    !  diagonal blocks are identity; only the off-diagonal blocks are read here.
    !
    !  Two chains are run side by side, driven by the same random numbers:
    !  VN starts from 1 and VN1 from 0. Sweeping stops as soon as the RMS
    !  difference between them over the sampled atoms is <= 1e-6, or after
    !  nburn+1 sweeps, whichever comes first.
    !  The random number generator is reseeded from the system clock on every call.
    !
    !  Arguments
    !    ndiag          off-diagonal 3x3 blocks live after am(ndiag), 9 reals each,
    !                   stored row by row; the block for neighbour entry JAD starts
    !                   at am(ndiag + 9*(|d|-1) + 1) with d = rest_per_atom_dist(jad).
    !                   For d < 0 the stored block is used transposed.
    !    vn             on exit: the sample. Atom IAA occupies vn(3k-2:3k) with
    !                   k = atom_ref_flag(iaa)/10. The value on entry is ignored.
    !    atom_ref_flag  atoms with atom_ref_flag <= 0 (or k <= 0) are skipped.
    !    atom_fix_flag  0  : sampled by the Gibbs sweeps;
    !                   -1 : "free" atom, filled with independent N(0,1) draws
    !                        after the sweeps;
    !                   any other value: left at zero.
    !    am             packed, conditioned matrix.
    !    nrest_per_atom, rest_per_atom_pos, rest_per_atom, rest_per_atom_dist
    !                   neighbour lists: the neighbours of atom IAA are
    !                   rest_per_atom(p:p+n-1) with p = rest_per_atom_pos(iaa),
    !                   n = nrest_per_atom(iaa); rest_per_atom_dist gives d.
    !    nburn_in       optional number of burn-in sweeps. Absent -> 0;
    !                   negative -> 1000.
    !
    !  The program is stopped if an off-diagonal block element has magnitude >= 1.
    !

    integer, intent(in) :: ndiag
    integer, intent(in) :: atom_ref_flag(:)
    integer, intent(in) :: atom_fix_flag(:)
    integer, intent(in) :: nrest_per_atom(:)
    integer, intent(in) :: rest_per_atom_pos(:)
    integer, intent(in) :: rest_per_atom(:)
    integer, intent(in) :: rest_per_atom_dist(:)
    real, intent(inout) :: vn(:)
    real, intent(in) :: am(:)
    integer, intent(in), optional :: nburn_in

    !
    !  locals
    integer nd1
    integer i,l,l1
    integer nv,ia,iaa,ia1,ia2,ic,jaa,jad,ja1
    integer nburn
    integer n_atom
    integer npos,npos1,nd_ia
    integer nrr
    integer isize,clock
    real vv1,vv2,vv3
    real am11,am12,am13,am21,am22,am23,am31,am32,am33
    real v_l(3)
    real ss1
    real runiform
    real gamma
    real, allocatable :: vo(:),vo1(:),vn1(:)
    integer, allocatable :: atom_loc_ref(:)
    integer, allocatable :: ia2_l(:),ia_visited(:)
    integer, allocatable :: sweep_order(:)
    integer, allocatable :: iuniform(:)
    integer, allocatable :: seed_put(:)

    n_atom = size(atom_ref_flag)

    if(present(nburn_in))then
       nburn=nburn_in
    else
       nburn = 0
    endif
    if(nburn.lt.0) nburn = 1000

    nv = size(vn)
    allocate(vo(nv),vo1(nv),vn1(nv))
    allocate(atom_loc_ref(n_atom),ia2_l(n_atom),ia_visited(n_atom))
    allocate(iuniform(n_atom),sweep_order(n_atom))

    ! Block index k of each atom in VN and the position of its first component.
    do ia=1,n_atom
       atom_loc_ref(ia) = atom_ref_flag(ia)/10
       ia2_l(ia) = 3*atom_loc_ref(ia)-2
    enddo
    ! Off-diagonal block for index d starts at am(nd1 + 9*|d|) = am(ndiag + 9*(|d|-1) + 1).
    nd1 = ndiag-8
    write(*,*)nburn

    ! Chain 1 (vn) starts from 1, chain 2 (vn1) from 0.
    vo1 = 0.0
    vn1 = 0.0
    vo = 0.0
    vn = 1.00
    ! Damping factor; 0 gives the plain Gibbs update.
    gamma = 0.0

    ! Seed from the clock. PUT must have at least as many elements as the
    ! processor's seed, which is compiler dependent, so size it from
    ! RANDOM_SEED(SIZE=...) (with at least 12 elements).
    call system_clock(count=clock)
    call random_seed(size=isize)
    allocate(seed_put(max(isize,12)))
    seed_put = clock + 37*(/(i - 1, i = 1, size(seed_put)) /)
    call random_seed(put=seed_put)

    do ic=1,nburn+1
       ! Previous sweep becomes the "old" state; the new sweep starts from 0.
       vo = vn
       vn = 0.0
       vo1 = vn1
       vn1 = 0.0
       ia_visited(1:n_atom) = 0

       ! Random visiting order. NOTE(review): the order is computed but not
       ! used - the sweep below visits atoms in their natural order.
       do iaa=1,n_atom
          call random_number(runiform)
          iuniform(iaa) = int(runiform*100*n_atom)+1
          sweep_order(iaa) = iaa
       enddo
       call iheap_sort_r(n_atom,1,iuniform,sweep_order)

       do iaa=1,n_atom
          ! Marked visited even when skipped below, so later atoms read its
          ! current-sweep value from VN/VN1 (zero for skipped atoms).
          ia_visited(iaa) = 1
          if(atom_ref_flag(iaa).le.0.or.atom_fix_flag(iaa).ne.0) cycle
          ia = atom_loc_ref(iaa)
          if(ia.le.0) cycle
          ia1 = ia2_l(iaa)
          ia2 = ia1 + 2
          ! The same random vector drives both chains.
          call random_vector_gauss(3,vn(ia1:ia2))
          vn(ia1:ia2) = vn(ia1:ia2)/(1.0+gamma)**0.5
          vn1(ia1:ia2) = vn(ia1:ia2)

          nd_ia = nrest_per_atom(iaa)
          npos  = rest_per_atom_pos(iaa)
          npos1 = npos+nd_ia-1
          if(nd_ia.le.0) cycle
          do jad=npos,npos1
             jaa = rest_per_atom(jad)
             l1 = rest_per_atom_dist(jad)
             l = nd1 + 9*abs(l1)
             ja1 = ia2_l(jaa)
             if(maxval(abs(am(l:l+8))).ge.1.0) then
                write(*,*)'gibbs_sampler: off-diagonal element with magnitude >= 1; atom ',iaa,' entry ',jad
                stop 'gibbs_sampler: invalid conditioned matrix'
             endif
             if(l1.gt.0) then
                ! Block stored row by row.
                am11 = am(l)
                am12 = am(l+1)
                am13 = am(l+2)
                am21 = am(l+3)
                am22 = am(l+4)
                am23 = am(l+5)
                am31 = am(l+6)
                am32 = am(l+7)
                am33 = am(l+8)
             else
                ! Negative index: use the stored block transposed.
                am11 = am(l)
                am21 = am(l+1)
                am31 = am(l+2)
                am12 = am(l+3)
                am22 = am(l+4)
                am32 = am(l+5)
                am13 = am(l+6)
                am23 = am(l+7)
                am33 = am(l+8)
             endif
             if(ia_visited(jaa).eq.1) then
                ! Neighbour already updated in this sweep: use its new value.
                vv1 = vn(ja1);vv2=vn(ja1+1);vv3=vn(ja1+2)
                v_l(1) = am11*vv1 + am12*vv2 + am13*vv3
                v_l(2) = am21*vv1 + am22*vv2 + am23*vv3
                v_l(3) = am31*vv1 + am32*vv2 + am33*vv3
                vn(ia1:ia2) = vn(ia1:ia2) - v_l/(1.0+gamma)

                vv1 = vn1(ja1);vv2=vn1(ja1+1);vv3=vn1(ja1+2)
                v_l(1) = am11*vv1 + am12*vv2 + am13*vv3
                v_l(2) = am21*vv1 + am22*vv2 + am23*vv3
                v_l(3) = am31*vv1 + am32*vv2 + am33*vv3
                vn1(ia1:ia2) = vn1(ia1:ia2) - v_l/(1.0+gamma)
             else
                ! Neighbour not yet updated: use its value from the previous sweep.
                vv1 = vo(ja1);vv2=vo(ja1+1);vv3=vo(ja1+2)
                v_l(1) = am11*vv1 + am12*vv2 + am13*vv3
                v_l(2) = am21*vv1 + am22*vv2 + am23*vv3
                v_l(3) = am31*vv1 + am32*vv2 + am33*vv3
                vn(ia1:ia2) = vn(ia1:ia2) - v_l/(1.0+gamma)

                vv1 = vo1(ja1);vv2=vo1(ja1+1);vv3=vo1(ja1+2)
                v_l(1) = am11*vv1 + am12*vv2 + am13*vv3
                v_l(2) = am21*vv1 + am22*vv2 + am23*vv3
                v_l(3) = am31*vv1 + am32*vv2 + am33*vv3
                vn1(ia1:ia2) = vn1(ia1:ia2) - v_l/(1.0+gamma)
             endif
          enddo
       enddo

       ! Coupling check: RMS difference between the two chains.
       ss1 = 0.0
       nrr = 0
       do iaa=1,n_atom
          if(atom_ref_flag(iaa).le.0.or.atom_fix_flag(iaa).ne.0) cycle
          ia = atom_loc_ref(iaa)
          if(ia.le.0) cycle
          ia1 = ia2_l(iaa)
          ia2 = ia1 + 2
          ss1 = ss1 + sum((vn1(ia1:ia2)-vn(ia1:ia2))**2)
          nrr = nrr + 1
       enddo
       ! No sampled atoms: further sweeps cannot change anything (and ss1/nrr would be 0/0).
       if(nrr.eq.0) exit
       if(sqrt(ss1/nrr).le.0.000001) exit
    enddo

    ! Free atoms (atom_fix_flag = -1) need independent random numbers only once.
    do iaa=1,n_atom
       if(atom_ref_flag(iaa).le.0.or.atom_fix_flag(iaa).ne.-1) cycle
       ! Same guard as the sweep: atoms without a block index have no slot in VN.
       if(atom_loc_ref(iaa).le.0) cycle
       ia1 = ia2_l(iaa)
       ia2 = ia1 + 2
       call random_vector_gauss(3,vn(ia1:ia2))
    enddo

    deallocate(atom_loc_ref,ia2_l,ia_visited)
    deallocate(iuniform,sweep_order)
    deallocate(vo,vo1,vn1)
    deallocate(seed_put)

  end subroutine gibbs_sampler

  subroutine simple_atom_conditioner(ndiag,am,xcond,atom_ref_flag,atom_fix_flag,nrest_per_atom,rest_per_atom, &
    rest_per_atom_pos,rest_per_atom_dist)

    !
    !   Build 3x3 block conditioners (for linear-equation solvers and Gibbs
    !   samplers) and apply them in place to the packed matrix AM.
    !
    !   AM layout as read/written here:
    !     - diagonal blocks: 6 reals per atom with atom_ref_flag > 0, stored
    !       consecutively from am(1) as (a11, a22, a33, a12, a13, a23);
    !     - off-diagonal blocks: 9 reals each, row by row; the block for a
    !       neighbour entry with rest_per_atom_dist = d > 0 starts at
    !       am(ndiag + 9*(d-1) + 1).
    !
    !   For the k-th atom with atom_ref_flag > 0:
    !     atom_fix_flag == 1 : xcond(:,:,k) = I/sqrt(trace/3) (zero if trace <= 0);
    !                          the diagonal block in AM is left unchanged.
    !     otherwise          : xcond(:,:,k) comes from deigen_filter_invert_f90_r
    !                          with power -0.5 (presumably the inverse square root
    !                          of the diagonal block - confirm in linalgebra_f90);
    !                          the diagonal block in AM is reset to identity.
    !   Every off-diagonal block with d > 0 becomes xcond(k_i) * A_ij * xcond(k_j),
    !   or zero when either atom has atom_fix_flag == 1. Entries with d <= 0 are
    !   left untouched.
    !

    integer ndiag
    integer, intent(in) :: atom_ref_flag(:)
    integer, intent(in) :: atom_fix_flag(:)
    integer, intent(in) :: nrest_per_atom(:)
    integer, intent(in) :: rest_per_atom_pos(:)
    integer, intent(in) :: rest_per_atom(:)
    integer, intent(in) :: rest_per_atom_dist(:)
    !
    real, intent(inout) :: am(:)
    real, intent(out) :: xcond(:,:,:)
    !
    !   locals
    integer :: n_atom, iat, jat, irest, ibeg, iend
    integer :: kdiag, kcond, kcond_nb, koff
    integer, allocatable :: cond_index(:)
    real(kind=8) :: blk(3,3), blk_inv(3,3)
    real(kind=8) :: toler=1.0d-8
    real :: pw, scl
    real :: offd(3,3)

    pw = -0.5
    n_atom = size(atom_ref_flag)
    allocate(cond_index(n_atom))
    cond_index = 0

    !
    !   Pass 1: one conditioner per atom with atom_ref_flag > 0.
    !
    kdiag = 1
    kcond = 0
    do iat = 1, n_atom
       if (atom_ref_flag(iat) <= 0) cycle
       kcond = kcond + 1
       cond_index(iat) = kcond

       ! Expand the packed symmetric block (a11,a22,a33,a12,a13,a23).
       blk(1,1) = am(kdiag)
       blk(2,2) = am(kdiag+1)
       blk(3,3) = am(kdiag+2)
       blk(1,2) = am(kdiag+3)
       blk(1,3) = am(kdiag+4)
       blk(2,3) = am(kdiag+5)
       blk(2,1) = blk(1,2)
       blk(3,1) = blk(1,3)
       blk(3,2) = blk(2,3)

       if (atom_fix_flag(iat) == 1) then
          scl = (blk(1,1) + blk(2,2) + blk(3,3))/3.0
          if (scl > 0.0) then
             scl = 1.0/sqrt(scl)
          else
             scl = 0.0
          endif
          xcond(1:3,1:3,kcond) = 0.0
          xcond(1,1,kcond) = scl
          xcond(2,2,kcond) = scl
          xcond(3,3,kcond) = scl
       else
          call deigen_filter_invert_f90_r(blk,blk_inv,toler,pw)
          xcond(1:3,1:3,kcond) = blk_inv
          am(kdiag:kdiag+2) = 1.0
          am(kdiag+3:kdiag+5) = 0.0
       endif
       kdiag = kdiag + 6
    enddo

    !
    !   Pass 2: apply the conditioners to the off-diagonal blocks.
    !
    do iat = 1, n_atom
       if (atom_ref_flag(iat) <= 0) cycle
       ibeg = rest_per_atom_pos(iat)
       iend = ibeg + nrest_per_atom(iat) - 1
       kcond = cond_index(iat)
       do irest = ibeg, iend
          jat = rest_per_atom(irest)
          kcond_nb = cond_index(jat)
          if (rest_per_atom_dist(irest) <= 0) cycle
          koff = ndiag + 9*(rest_per_atom_dist(irest)-1) + 1
          ! Stored row by row: transpose the column-major reshape.
          offd = transpose(reshape(am(koff:koff+8), (/3,3/)))
          if (atom_fix_flag(iat) == 1 .or. atom_fix_flag(jat) == 1) then
             offd = 0.0
          else
             offd = matmul(matmul(xcond(1:3,1:3,kcond),offd),xcond(1:3,1:3,kcond_nb))
          endif
          am(koff:koff+8) = reshape(transpose(offd), (/9/))
       enddo
    enddo

    deallocate(cond_index)
    return
  end subroutine simple_atom_conditioner

  subroutine decondition_shifts(shifts,xcond,atom_ref_flag)
    !
    !   Undo the block conditioning of a shift vector.
    !
    !   Atoms with atom_ref_flag > 0 are numbered consecutively (k = 1, 2, ...)
    !   in the order they appear. For the k-th such atom, the three components
    !   shifts(3k-2:3k) are replaced by xcond(:,:,k) times themselves.
    !   Atoms with atom_ref_flag <= 0 are skipped and do not advance k.
    !
    real, intent(in) :: xcond(:,:,:)
    integer, intent(in) :: atom_ref_flag(:)
    real, intent(inout) :: shifts(:)

    integer :: iatom, kblk, ifirst

    kblk = 0
    do iatom = 1, size(atom_ref_flag)
       if (atom_ref_flag(iatom) <= 0) cycle
       kblk = kblk + 1
       ifirst = 3*kblk - 2
       ! The right-hand side is evaluated completely before the assignment,
       ! so the same section may appear on both sides.
       shifts(ifirst:ifirst+2) = matmul(xcond(1:3,1:3,kblk), shifts(ifirst:ifirst+2))
    end do

  end subroutine decondition_shifts


  subroutine random_vector_gauss(nrand,x_this)
    !
    !---Fill x_this(1:nrand) with independent standard normal deviates,
    !---drawn one at a time (in index order) from gauss_random.
    integer nrand
    real x_this(nrand)

    integer k

    k = 0
    do while (k < nrand)
       k = k + 1
       x_this(k) = gauss_random()
    enddo

    return
  end subroutine random_vector_gauss
  !
  real function gauss_random()
    !
    !---  Standard normal (mean 0, variance 1) random number generator.
    !---  Uses the Marsaglia polar form of the Box-Muller transformation:
    !---  a point (x1,x2) is drawn uniformly in the square [-1,1)^2 with the
    !---  intrinsic random_number and rejected unless 0 < w = x1**2+x2**2 < 1.
    !---  Each accepted point yields TWO independent deviates: the first is
    !---  returned, the second is SAVEd and returned by the next call.
    !
    real w,w1,x1,x2
    real xrand1,xrand2

    integer ihave
    real rand
    real, parameter :: one=1.0
    save ihave,rand
    data ihave/0/

    if(ihave.eq.0) then
       w = 2.0
       ! The polar method needs 0 < w < 1: w >= 1 lies outside the unit disc,
       ! and w == 0 would make w1 = sqrt(-2 log(w)/w) infinite (0*Inf = NaN).
       do while (w.ge.one .or. w.le.0.0)
          call random_number(xrand1)
          call random_number(xrand2)
          x1 = scale(xrand1,1) - one
          x2 = scale(xrand2,1) - one
          w = x1*x1+x2*x2
       enddo
       w1 = sqrt((-2.0*log(w))/w)
       gauss_random = x1*w1
       rand = x2*w1
       ihave = 1
    else
       gauss_random = rand
       ihave = 0
    endif

    return
  end function gauss_random
  
  subroutine random_exponential_refmac(lambda,shift,erand)
    !
    !   Draw from the shifted exponential distribution with rate LAMBDA
    !   (mean 1/lambda) starting at SHIFT:
    !       erand = shift - log(u)/lambda,  u uniform on (0,1).
    !   Assumes lambda > 0 (not checked).
    real, intent(in) :: lambda,shift
    real, intent(out) :: erand

    real xrand

    ! random_number returns values in [0,1); an exact 0 would give
    ! -log(0) = +Inf, so redraw until the value is strictly positive.
    xrand = 0.0
    do while (xrand.le.0.0)
       call random_number(xrand)
    enddo
    erand = -log(xrand)/lambda + shift

    return
  end subroutine random_exponential_refmac

  subroutine random_truncated_gauss(shift,tgauss)
    !
    !   Fill TGAUSS with draws from the standard normal distribution
    !   truncated to [shift, +inf).
    !   Reference: C.P. Robert, "Simulation of truncated normal variables",
    !   http://arxiv.org/abs/0907.4010v1
    !     shift <= 0 : plain rejection - draw N(0,1) until the value is >= shift.
    !     shift >  0 : rejection from a shifted exponential proposal with rate
    !                  alpha = (shift + sqrt(shift**2+4))/2, accepting a
    !                  candidate z with probability exp(-(z-alpha)**2/2).
    real, intent(in) :: shift
    real, intent(out) :: tgauss(:)

    integer k
    real z, u, accept_prob, alpha_opt
    logical naive

    naive = shift.le.0.0
    if (.not.naive) alpha_opt = (shift+sqrt(shift**2+4.0))/2.0

    do k = 1, size(tgauss)
       if (naive) then
          do
             z = gauss_random()
             if (.not.(z.lt.shift)) exit
          enddo
       else
          do
             call random_exponential_refmac(alpha_opt,shift,z)
             accept_prob = exp(-(z-alpha_opt)**2/2.0)
             call random_number(u)
             if (.not.(u.gt.accept_prob)) exit
          enddo
       endif
       tgauss(k) = z
    enddo
  end subroutine random_truncated_gauss

end module gibbs_gm
