#!/usr/bin/python 

# pyCmpl is star-imported first so that the explicit stdlib imports below
# take precedence over any same-named object it might export.
from pyCmpl import *

import os
import random
import shutil
import sys
import tempfile
import time



class PMedianException(Exception):
    """Error raised by PMedian for invalid input vectors or when no solution is found."""


#Big M used for the WLP model
bigM = 1000000000


#*************** PMedian  ************************************************
class PMedian(object):
    """Solve a p-median problem with CMPL and the HiGHS solver.

    The problem is given by a source-by-destination cost matrix plus supply
    and demand vectors. ``solve()`` builds the CMPL sets/parameters, runs the
    model file ``CmplApps/pMedian.cmpl`` below ``LogLabHomePath`` and stores
    the objective value, the active sources and the destination-to-source
    mapping on the instance.

    Instances with more nodes than ``__NodeHeuristics`` are solved in two
    stages: the model is first solved with ``isRelaxed = 1`` on a pruned
    edge set; then edges whose assignment variable is zero in that solution
    are randomly passed to the model as ``fixedEdges`` and the model is
    solved again with ``isRelaxed = 0``. Smaller instances are solved
    directly on the full edge set.
    """

    # Absolute tolerance used when comparing solver activities with 0 or 1,
    # since solver values may deviate slightly from exact integers.
    _ACTIVITY_TOL = 1e-6

    #*********** constructor **********
    def __init__(self, nrOfSources, nrOfDests, p, varCosts, supplies, demands, isCap, isDemandWeighted, LogLabHomePath):
        """Store the problem data and precompute per-source cost statistics.

        Args:
            nrOfSources: number of sources (rows of ``varCosts``).
            nrOfDests: number of destinations (columns of ``varCosts``).
            p: passed to the model as parameter ``p``.
            varCosts: cost matrix indexed ``varCosts[source][dest]`` (0-based).
                Costs are assumed non-negative (TODO confirm with callers).
            supplies: list passed to the model as parameter ``s`` (per source).
            demands: list passed to the model as parameter ``d`` (per destination).
            isCap: truthy value passed to the model as ``isCap`` (0/1); also
                enables cost-based edge pruning in the heuristic path.
            isDemandWeighted: truthy value passed to the model as
                ``isDemandWeighted`` (0/1).
            LogLabHomePath: installation root. It is concatenated directly
                with ``'Cmpl'`` and ``'CmplApps'``, so it presumably ends with
                a path separator.

        Raises:
            PMedianException: if ``supplies`` or ``demands`` is not a list.
        """
        self.__home = LogLabHomePath
        cmplHomePath = LogLabHomePath + 'Cmpl'
        os.environ.update({'CMPLHOME': cmplHomePath})

        self.__model = None
        self.__isCap = 1 if isCap else 0
        self.__isDemandWeighted = 1 if isDemandWeighted else 0

        self.__varCosts = varCosts
        self.__nrOfDestinations = nrOfDests
        self.__nrOfSources = nrOfSources
        self.__p = p

        # Instances with more nodes than this use the two-stage heuristic.
        self.__NodeHeuristics = 100
        self.__mipGap = 0.001
        self.__debug = False

        # Total number of nodes: sources plus destinations.
        self.__nrOfNodes = self.__nrOfSources + self.__nrOfDestinations

        # Size-dependent MIP gap; only active when __mipGap is set to 0 above.
        if self.__mipGap == 0:
            if self.__nrOfNodes < 50:
                self.__mipGap = 0.001
            elif self.__nrOfNodes < 100:
                self.__mipGap = 0.02
            elif self.__nrOfNodes < 150:
                self.__mipGap = 0.03
            else:
                self.__mipGap = 0.05

        self.__edges = []
        self.__vCosts = []

        self.__costTolAdder = 0
        self.__maxCostSrc = []
        self.__minCostSrc = []
        self.__avgCostSrc = []
        self.__costTolSrc = []

        # Per-source minimum, maximum and average cost over all destinations.
        for i in range(self.__nrOfSources):
            row = [self.__varCosts[i][j] for j in range(self.__nrOfDestinations)]
            self.__minCostSrc.append(min(row, default=0))
            self.__maxCostSrc.append(max(row, default=0))
            self.__avgCostSrc.append(sum(row) / len(row) if row else 0)

        self.__setQuantityVectors(supplies, demands)

        self.__totalCosts = 0
        self.__activeSources = [0] * nrOfSources
        self.__destSourceMapping = [-1] * nrOfDests

        self.solutionStatus = False

        self.__tmpPath = tempfile.gettempdir() + os.sep
    #*********** end constructor ******

    def __setQuantityVectors(self, supplies, demands):
        """Validate and store the supply and demand vectors.

        Raises:
            PMedianException: if either vector is not a list.
        """
        if not isinstance(supplies, list):
            raise PMedianException('PMedian error: Wrong supply vector ')
        self.__supplies = supplies

        if not isinstance(demands, list):
            raise PMedianException('PMedian error: Wrong demand vector ')
        self.__demands = demands

    def setQuantities(self, supplies, demands):
        """Replace supplies and demands and switch to the capacitated variant.

        Raises:
            PMedianException: if either vector is not a list.
        """
        self.__setQuantityVectors(supplies, demands)
        self.__isCap = 1

    def getObjValue(self):
        """Return the objective value of the last successful solve (0 before)."""
        return self.__totalCosts

    def getActiveSources(self):
        """Return a list with 1 for every source open in the last solution, else 0."""
        return self.__activeSources

    def getDestSourceMapping(self):
        """Return the 1-based source assigned to each destination (-1 = none)."""
        return self.__destSourceMapping

    def isSolution(self):
        """Return True if the last call to solve() found a solution."""
        return self.solutionStatus

    def __setCostTol(self):
        """(Re)compute the base cost tolerance of every source.

        The tolerance is a fraction of the source's maximum cost: __setEdges
        keeps edge (i, j) when its cost is at most
        ``maxCostSrc[i] * (costTolSrc[i] + costTolAdder)``. The list is
        rebuilt on every call, so repeated calls do not accumulate entries.
        """
        factor = 1.25
        self.__costTolSrc = []
        for minCost, maxCost, avgCost in zip(self.__minCostSrc, self.__maxCostSrc, self.__avgCostSrc):
            if maxCost == 0:
                # All costs of this source are zero: the ratio is undefined,
                # and any tolerance >= 1 keeps every edge of the source.
                tol = factor
            elif avgCost > minCost:
                tol = avgCost / maxCost * factor
            else:
                # avg == min means all costs are equal (so min == max).
                tol = minCost / maxCost * factor
            self.__costTolSrc.append(tol)

    #*********** __setEdges     ************
    def __setEdges(self):
        """Build the 1-based edge list and cost vector passed to the model.

        In the capacitated case an edge (i, j) is kept only when its cost is
        at most ``maxCostSrc[i] * (costTolSrc[i] + costTolAdder)``. The adder
        is raised in steps of 0.1 until every destination has an edge. In the
        uncapacitated case all edges are kept.

        Raises:
            PMedianException: if there are destinations but no sources, so no
                edge set could ever cover them.
        """
        if self.__nrOfSources == 0 and self.__nrOfDestinations > 0:
            raise PMedianException('PMedian error: no sources to assign destinations to')

        while True:
            self.__edges = []
            self.__vCosts = []
            covered = [False] * self.__nrOfDestinations

            for i in range(self.__nrOfSources):
                # The tolerance is not mutated here; only the adder grows.
                maxCost = self.__maxCostSrc[i] * (self.__costTolSrc[i] + self.__costTolAdder)
                for j in range(self.__nrOfDestinations):
                    cost = self.__varCosts[i][j]
                    if not self.__isCap or cost <= maxCost:
                        self.__edges.append((i + 1, j + 1))
                        self.__vCosts.append(cost)
                        covered[j] = True

            if all(covered):
                break
            # Once tol + adder >= 1 every edge passes, so this terminates.
            self.__costTolAdder += 0.1
    #*********** __setEdges     ************

    #*********** __fixVariables  ************
    def __fixVariables(self):
        """Randomly choose edges to pass to the model as ``fixedEdges``.

        Each edge whose assignment variable x is (numerically) zero in the
        last solution is selected with probability ``__fixVarProb``.

        Raises:
            PMedianException: if an activity cannot be read from the model.
        """
        self.__fixedEdges = []

        for i, j in self.__edges:
            try:
                activity = self.__model.x[(i, j)].activity
            except Exception as err:
                raise PMedianException("Error while fixing vars <" + str(err) + ">") from err
            if abs(activity) <= self._ACTIVITY_TOL and random.random() < self.__fixVarProb:
                self.__fixedEdges.append((i, j))

        if not self.__fixedEdges:
            # Nothing selected: fall back to the same (-1, -1) placeholder that
            # solve() starts with (presumably CMPL needs a non-empty set - confirm).
            self.__fixedEdges = [(-1, -1)]

    def __checkCostTol(self):
        """Return True while some source's effective tolerance is below 1.

        Once base tolerance + adder reaches 1, every edge of that source is
        already kept, so widening the tolerance cannot add more edges.
        """
        return any(tol + self.__costTolAdder < 1 for tol in self.__costTolSrc)

    def __newModel(self, modelName, sets, parameters, timeLimit, mipGap=None):
        """Create a Cmpl model with the given sets/parameters and HiGHS options.

        ``mipGap``, when given, is passed to HiGHS as ``mip_rel_gap``.
        """
        model = Cmpl(modelName)
        model.setSets(*sets)
        model.setParameters(*parameters)
        model.debug(self.__debug)
        model.setOption('-solver highs')
        model.setOption(f'-opt highs time_limit={timeLimit}')
        if mipGap is not None:
            model.setOption(f'-opt highs mip_rel_gap={mipGap}')
        return model

    def __readSolution(self, sources, destinations):
        """Copy objective, open sources and assignments from the last solve.

        The result lists are reset in place first, so sources opened by an
        earlier solve (e.g. the relaxed model) are not reported as active.
        Lookups that fail (e.g. x[(i, j)] for an edge not in the model) are
        skipped on a best-effort basis.
        """
        self.__totalCosts = self.__model.solution.value
        self.solutionStatus = True
        self.__activeSources[:] = [0] * self.__nrOfSources
        self.__destSourceMapping[:] = [-1] * self.__nrOfDestinations

        for i in sources:
            try:
                if abs(self.__model.y[i].activity - 1) <= self._ACTIVITY_TOL:
                    self.__activeSources[i - 1] = 1
            except Exception:
                # Source variable unavailable: skip its assignments as well.
                continue

            for j in destinations:
                try:
                    if self.__model.x[(i, j)].activity > self._ACTIVITY_TOL:
                        self.__destSourceMapping[j - 1] = i
                except Exception:
                    # Presumably (i, j) is not an edge of the model.
                    pass

    #*********** solve ************
    def solve(self, timeLimit=300):
        """Solve the model and store the solution on the instance.

        Args:
            timeLimit: wall-clock budget in seconds, checked between solver
                runs; it is also passed to HiGHS as ``time_limit`` for every
                single solver run.

        Raises:
            PMedianException: if no solution is found within ``maxTries``
                attempts or ``timeLimit`` seconds.
        """
        maxTries = 5
        tries = 1
        sTime = time.time()
        self.solutionStatus = False

        def timedOut():
            return (time.time() - sTime) > timeLimit

        w = CmplSet("sources")
        w.setValues(1, self.__nrOfSources)

        c = CmplSet("destinations")
        c.setValues(1, self.__nrOfDestinations)

        self.__costTolAdder = 0
        self.__setCostTol()

        p = CmplParameter("p")
        p.setValues(self.__p)

        cap = CmplParameter("s", w)
        cap.setValues(self.__supplies)

        dem = CmplParameter("d", c)
        dem.setValues(self.__demands)

        isCap = CmplParameter('isCap')
        isCap.setValues(self.__isCap)

        isDemandWeighted = CmplParameter("isDemandWeighted")
        isDemandWeighted.setValues(self.__isDemandWeighted)

        edges = CmplSet("edges", 2)

        # Placeholder meaning "no edge fixed".
        self.__fixedEdges = [(-1, -1)]
        fEdges = CmplSet("fixedEdges", 2)
        fEdges.setValues(self.__fixedEdges)

        costs = CmplParameter("c", edges)

        isRelaxed = CmplParameter('isRelaxed')

        sets = (fEdges, edges, w, c)
        parameters = (costs, p, cap, dem, isCap, isDemandWeighted, isRelaxed)

        # Copy the model file into the temp directory and solve from there.
        modelName = self.__tmpPath + "pMedian.cmpl"
        shutil.copyfile(self.__home + 'CmplApps' + os.sep + 'pMedian.cmpl', modelName)

        self.__fixVarProb = 0.4

        while True:
            if self.__nrOfNodes > self.__NodeHeuristics:
                # Stage 1: relaxed model on a pruned edge set, widened on failure.
                while True:
                    self.__setEdges()
                    edges.setValues(self.__edges)
                    costs.setValues(self.__vCosts)
                    isRelaxed.setValues(1)

                    self.__model = self.__newModel(modelName, sets, parameters, timeLimit)

                    print('    ... solving relaxtion ')
                    self.__model.solve()

                    if self.__model.solverStatus == SOLVER_OK:
                        self.__readSolution(w.values, c.values)
                        print(f'        ... Solution has been found with obj value {self.__model.solution.value} ')
                        break

                    print('        ... No solution has been found')
                    if not self.__checkCostTol():
                        raise PMedianException(' No solution has been found with relaxation - checkTol')

                    # A larger adder widens the edge set on the next __setEdges().
                    self.__costTolAdder += 0.1

                    if timedOut():
                        raise PMedianException(' No solution has been found with relaxation ')

                print('    ... solving fixed model ')
                self.__fixVariables()
            else:
                # Small instance: use every source-destination edge.
                self.__edges = [(i + 1, j + 1)
                                for i in range(self.__nrOfSources)
                                for j in range(self.__nrOfDestinations)]
                self.__vCosts = [self.__varCosts[i][j]
                                 for i in range(self.__nrOfSources)
                                 for j in range(self.__nrOfDestinations)]
                edges.setValues(self.__edges)
                costs.setValues(self.__vCosts)

            # Stage 2 (or direct solve): full model with the chosen fixed edges.
            fEdges.setValues(self.__fixedEdges)
            isRelaxed.setValues(0)

            if timedOut():
                raise PMedianException(' No solution has been found')

            self.__model = self.__newModel(modelName, sets, parameters, timeLimit, mipGap=self.__mipGap)
            self.__model.solve()

            if self.__model.solverStatus == SOLVER_OK:
                self.__readSolution(w.values, c.values)
                break

            print('    ... No solution has been found after ' + str(tries) + ' tries')
            if tries >= maxTries or timedOut():
                raise PMedianException(' No solution has been found after ' + str(tries) + ' tries')

            # Retry with a wider edge set and fewer fixed edges.
            tries += 1
            self.__costTolAdder += 0.05
            self.__fixVarProb = max(0.0, round(self.__fixVarProb - 0.1, 10))
            print('    ... new try with fixVarProb: ', self.__fixVarProb)
    #*********** end solve **************
    
