#!/usr/bin/env python
import numpy as np
import matplotlib.pyplot as plt
import onnxruntime as onnxrt





# castle_onnx_model: Python class to load the ONNX model and import the model
# ... metadata. Metadata includes input map, reference dimension, and model limits.
# ... Class methods take model parameter inputs as Python **kwargs, process the
# ... parameters into the W/R, A/C, A/T, and R/T format expected by the ONNX model,
# ... and scale the output values using the input value of the reference dimension
# ... and the reference dimension value stored in the model metadata.
class castle_onnx_model(object):
    """Wrap a CAStLE ONNX model together with the metadata stored in the file.

    The model metadata provides the input map ('INPUT_MAP'), the reference
    dimension ('REFERENCE_DIMENSION') and a set of linear input limits
    ('INPUT_LIMITS_COEFF_MATRIX' / 'INPUT_LIMITS_VECTOR'). Public methods take
    the raw model parameters as keyword arguments (e.g. ``W=..., R=...``),
    build the inputs described by the input map, and scale the model output by
    the reference-dimension scale factor.
    """

    def __init__(self, modelfilename: str):
        """Load *modelfilename* on the CPU execution provider and parse its metadata.

        Prints the model file name and the sorted parameter names that must be
        passed to ``check_limits`` and ``calc_sif``.

        Raises:
            KeyError: if a required metadata entry is missing from the model.
            ValueError: if a metadata entry cannot be parsed.
        """
        self.modelfilename = modelfilename

        self.sess = onnxrt.InferenceSession(
            modelfilename, providers=["CPUExecutionProvider"]
        )

        metadata = self.sess.get_modelmeta().custom_metadata_map
        self._parse_param_map(metadata['INPUT_MAP'])
        self._parse_reference_dimension(metadata['REFERENCE_DIMENSION'])
        self._parse_model_limits(
            metadata['INPUT_LIMITS_COEFF_MATRIX'],
            metadata['INPUT_LIMITS_VECTOR'],
        )
        print(modelfilename + ' loaded. Parameters are ' + ", ".join(sorted(self.params)))

    def _parse_reference_dimension(self, refdimstr: str):
        """Parse a reference-dimension string such as 'RAD, 10.0'.

        Sets ``self.refdim`` (the dimension name, surrounding whitespace
        removed) and ``self.refval`` (the reference value as a float).

        Raises:
            ValueError: if the string is not of the form 'NAME, VALUE' or VALUE
                is not a number.
        """
        fields = refdimstr.split(',')
        if len(fields) != 2:
            raise ValueError(
                "Malformed REFERENCE_DIMENSION metadata: %r" % (refdimstr,)
            )
        self.refdim = fields[0].strip()
        self.refval = float(fields[1])

    @staticmethod
    def _parse_float_list(text: str):
        """Return the comma-separated numbers in *text* as floats, skipping empty fields."""
        return [float(field) for field in text.split(',') if field.strip()]

    def _parse_model_limits(self, coeffmatrixstr: str, limitvecstr: str):
        """Parse the linear model-limit coefficients and limit values.

        The coefficient matrix is stored column by column. Columns are
        separated by ';' and the entries of a column by ',':
        'r1c1, r2c1, ... rNc1; r1c2, ... rNc2; ... rNcM'. The limit values are
        stored as 'r1lim, r2lim, ... rNlim'. Sets ``self.limitcoeffs`` (N rows,
        M columns) and ``self.limitvec`` (N values). The model inputs are within
        the limits when every entry of limitvec - limitcoeffs @ inputs is >= 0.

        The strings are parsed as plain numbers instead of with ``eval``, so
        metadata read from a model file cannot execute arbitrary code.

        Raises:
            ValueError: if a field is not a number or the columns differ in length.
        """
        columns = [
            self._parse_float_list(column)
            for column in coeffmatrixstr.split(';')
            if column.strip()  # skip the empty field left by a trailing ';'
        ]
        # Each parsed list is one *column* of the limit matrix, hence the transpose.
        self.limitcoeffs = np.array(columns, dtype=float).T
        self.limitvec = np.array(self._parse_float_list(limitvecstr), dtype=float)

    def _parse_param_map(self, inputmap: str):
        """Record the unique parameter names and how each model input is built.

        *inputmap* is a comma-separated list whose entries are either a single
        parameter name ('Phi') or a ratio 'NUM over DEN' ('W over R'). After
        parsing:

        - ``self.params`` lists the unique parameter names in first-seen order;
        - ``self.numerators[i]`` is the index into ``params`` of model input i's
          numerator;
        - ``self.denominators[i]`` is the index of its denominator, or -1 when
          input i is not a ratio.

        The mapping is built once when the model is loaded and then reused to
        transform keyword inputs into the model's input vector.
        """
        self.inputmap = inputmap
        self.params = []
        self.numerators = []
        self.denominators = []

        def param_index(name):
            # Register a parameter the first time it appears so it keeps a stable index.
            if name not in self.params:
                self.params.append(name)
            return self.params.index(name)

        for entry in inputmap.split(','):
            if not entry.strip():
                continue  # tolerate trailing or doubled commas in the metadata
            # NOTE(review): 'over' is matched as a plain substring, so a parameter
            # name that itself contains 'over' would be split incorrectly.
            overidx = entry.find('over')
            if overidx > 0:
                numparam = entry[:overidx].strip()
                denomparam = entry[overidx + 4:].strip()
            else:
                numparam = entry.strip()
                denomparam = ''

            # The numerator is registered before the denominator, so the order of params is preserved.
            self.numerators.append(param_index(numparam))
            self.denominators.append(param_index(denomparam) if denomparam else -1)

    def _make_input_array(self, **kwargs):
        """Build the float32 model input vector from raw keyword parameters.

        Input i is ``kwargs[params[numerators[i]]]``, divided by
        ``kwargs[params[denominators[i]]]`` when input i is a ratio.

        Raises:
            ValueError: if any parameter listed in ``self.params`` is missing.
        """
        missing = [p for p in self.params if p not in kwargs]
        if missing:
            raise ValueError(
                "Cannot evaluate model: parameters have not all been specified"
                " (missing: " + ", ".join(missing) + ")"
            )

        params = np.zeros(len(self.numerators), dtype=np.float32)
        for i, (nidx, didx) in enumerate(zip(self.numerators, self.denominators)):
            value = kwargs[self.params[nidx]]
            if didx >= 0:
                # Divide before storing so the ratio is rounded to float32 only once.
                value = value / kwargs[self.params[didx]]
            params[i] = value
        return params

    def _check_limits(self, modelparams):
        """Return whether every limit satisfies limitvec - limitcoeffs @ modelparams >= 0."""
        margins = self.limitvec - np.matmul(self.limitcoeffs, modelparams)
        return np.all(margins >= 0)

    def _calc_sif(self, modelparams, refscalefactor):
        """Run the ONNX model on the already formatted inputs and scale its first output.

        *modelparams* is promoted to 2-D (a single-row batch) and fed to the
        session input named 'input'.
        """
        outputs = self.sess.run(None, {'input': np.atleast_2d(modelparams)})
        return outputs[0] * refscalefactor

    def get_reference_scale_factor(self, **kwargs):
        """Return sqrt(kwargs[D] / refval), where D is the first letter of the reference dimension.

        Supported reference dimensions start with 'W', 'T' or 'R'.

        Raises:
            ValueError: if the reference dimension is empty or starts with
                another letter.
            KeyError: if the matching parameter is not in *kwargs*.
        """
        key = self.refdim.strip()[:1]
        if key not in ('W', 'T', 'R'):
            raise ValueError(
                "Model reference dimension has not been set or is not one of"
                " W, T, R: %r" % (self.refdim,)
            )
        return np.sqrt(kwargs[key] / self.refval)

    def check_limits(self, **kwargs):
        """Return whether the raw keyword parameters lie within the model limits."""
        return self._check_limits(self._make_input_array(**kwargs))

    def calc_sif(self, **kwargs):
        """Format the raw keyword parameters, run the model, and return the scaled output."""
        params = self._make_input_array(**kwargs)
        rsf = self.get_reference_scale_factor(**kwargs)
        return self._calc_sif(params, rsf)

 



if __name__ == '__main__':
    castlemodel = castle_onnx_model("CAStLE_straightbore_cornercrack_tension.onnx")
    # prints the following output:
    # CAStLE_straightbore_cornercrack_tension.onnx loaded. Parameters are A, C, Phi, R, T, W

    # create a Python dictionary with keys for each listed parameter:
    modelparams = {
        'W': 1.2/2,
        'T': 0.25,
        'R': 0.25,
        'A': 0.05,
        'C': 0.05,
        'Phi': 3*np.pi/180
    }
    # pass parameters to the castle_onnx_model using **modelparams:
    print('Parameters within model limits: ', castlemodel.check_limits(**modelparams))

    # limits for Phi are from 0.052 to 1.52 radians (3 to 87 degrees) - can be deduced
    # ... from the onnx model metadata, but not shown here.
    # Sweep Phi over every whole degree from 3 to 87 inclusive, in radians.
    phiarr = np.arange(3, 87 + 1) * np.pi/180
    sifarr = np.zeros_like(phiarr)
    # calculate the SIF for each Phi value (the model output is indexed as [0, 0])
    for i, phi in enumerate(phiarr):
        modelparams['Phi'] = phi
        sifarr[i] = castlemodel.calc_sif(**modelparams)[0, 0]

    # save data to CSV format
    np.savetxt(
        'CAStLE_ONNX_Example_Case5Comparison.csv',
        np.column_stack([phiarr, sifarr, 10*sifarr]),
        delimiter=',',
        header='Phi, SIF, K_I_tension'
    )

    # Plotting is optional: only the import itself is guarded, so errors raised
    # while plotting are reported instead of silently ignored.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        plt = None

    if plt is not None:
        # plot K_I,tension vs Phi for this crack geometry
        fig = plt.figure()
        fig.suptitle('"ERSI Stress Intensity Comparisons Round Robin" Case 5')
        ax = fig.add_subplot(
            111, xlabel=r'$2\phi/\pi$', ylabel=r'$K_{t}$', xlim=[0, 1], ylim=[4, 11.0]
        )
        ax.grid()
        ax.plot(phiarr*2/np.pi, 10*sifarr)
        fig.savefig('CAStLE_ONNX_Example_Case5Comparison.pdf', bbox_inches='tight')
        plt.close(fig)