# coding: utf-8

# Read a certificate file and return its details
def get_cert_data(path, password=None):
    """Read a certificate file and return a dict describing it.

    Files whose name ends in ``.pfx`` are loaded as PKCS#12 archives (first
    with ``password`` when one is given, then without); everything else is
    loaded as PEM.

    For PEM files holding exactly two certificates, a best-effort order check
    runs first: when the first certificate expires later than the second, the
    two are swapped on disk and the file is read again. Any failure in that
    check is ignored.

    @param path      certificate file path
    @param password  optional PKCS#12 passphrase
    @return dict with the keys ``type``, ``hash``, ``number``, ``issuer``
            (only when a matching issuer component exists), ``notAfter``,
            ``notBefore``, ``version``, ``timeout``, ``endtime``, ``subject``
            and ``dns``; ``None`` when the file does not exist or anything
            goes wrong while reading or parsing it.
    """
    # Imported locally so the function does not depend on module-level imports.
    import os
    import time

    try:
        try:
            from OpenSSL import crypto
            from urllib3.contrib import pyopenssl as reqs
        except Exception:
            # The bindings could not be imported: try installing them, then retry.
            os.system(public.get_run_pip('[PIP] install crypto'))
            os.system(public.get_run_pip('[PIP] install pyopenssl'))

            from OpenSSL import crypto
            from urllib3.contrib import pyopenssl as reqs

        certPath = path
        if not os.path.exists(certPath):
            return None

        data = {}
        with open(certPath, 'rb') as f:
            pfx_buffer = f.read()

        if certPath.endswith('.pfx'):
            if password:
                try:
                    p12 = crypto.load_pkcs12(pfx_buffer, password)
                except Exception:
                    # Retry without the passphrase before giving up.
                    p12 = crypto.load_pkcs12(pfx_buffer)
            else:
                p12 = crypto.load_pkcs12(pfx_buffer)

            x509 = p12.get_certificate()
            data['type'] = 'pfx'
        else:
            # Check the order of the domain certificate and the CA certificate.
            try:
                pem_text = public.readFile(certPath)
                # split_ca_data only keeps the first two certificates, so
                # rewriting a longer chain would drop the remaining ones.
                if pem_text.count('-----END CERTIFICATE-----') == 2:
                    cert_data = split_ca_data(pem_text)
                    cert_d1 = get_cert_info(cert_data['cert'])
                    cert_d2 = get_cert_info(cert_data['ca_data'])
                    if cert_d1['unix_timeout'] > cert_d2['unix_timeout']:
                        public.writeFile(certPath, cert_data['ca_data'] + cert_data['cert'])
                        with open(certPath, 'rb') as f:
                            pfx_buffer = f.read()
            except Exception:
                # Best effort only: fall through and parse the file as it is.
                pass

            x509 = crypto.load_certificate(crypto.FILETYPE_PEM, pfx_buffer)
            data['type'] = 'pem'

        buffs = x509.digest('sha1')
        data['hash'] = bytes.decode(buffs).replace(':', '')
        data['number'] = x509.get_serial_number()

        # Use the issuer's "O" component, or "CN" when it is the only component.
        components = x509.get_issuer().get_components()
        is_key = 'CN' if len(components) == 1 else 'O'
        for item in components:
            if bytes.decode(item[0]) == is_key:
                data['issuer'] = bytes.decode(item[1])
                break

        # [:-1] drops the last character of the timestamp before parsing.
        data['notAfter'] = strfToTime(bytes.decode(x509.get_notAfter())[:-1])
        data['notBefore'] = strfToTime(bytes.decode(x509.get_notBefore())[:-1])
        data['version'] = x509.get_version()
        data['timeout'] = x509.has_expired()

        # Whole days between now and the notAfter date (negative once passed).
        expires_at = time.mktime(time.strptime(data['notAfter'], "%Y-%m-%d"))
        data['endtime'] = int(int(expires_at - time.time()) / 86400)

        data['subject'] = x509.get_subject().commonName.replace('*', '_')
        # Keep the second element of every subject-alt-name entry.
        data['dns'] = [x[1] for x in reqs.get_subj_alt_name(x509)]
        return data
    except Exception:
        # print(public.get_error_info())
        return None


def split_ca_data(cert):
    """Split PEM text into its first two certificates.

    @param cert  PEM text containing at least one
                 ``-----END CERTIFICATE-----`` marker
    @return dict with ``cert`` (text up to and including the first marker)
            and ``ca_data`` (text between the first and second markers, with
            the marker appended). Anything after the second marker is
            discarded.
    @raise IndexError when ``cert`` contains no end marker
    """
    datas = cert.split('-----END CERTIFICATE-----')
    return {
        "cert": datas[0] + "-----END CERTIFICATE-----\n",
        "ca_data": datas[1] + '-----END CERTIFICATE-----\n',
    }


# Convert a compact timestamp string to a date string
def strfToTime(sdate):
    """Convert a ``YYYYMMDDHHMMSS`` timestamp string to ``YYYY-MM-DD``.

    @param sdate  timestamp text in the ``%Y%m%d%H%M%S`` layout
    @return the date part formatted as ``%Y-%m-%d``
    @raise ValueError when ``sdate`` does not match the expected layout
    """
    import time

    parsed = time.strptime(sdate, '%Y%m%d%H%M%S')
    return time.strftime('%Y-%m-%d', parsed)

def get_ssl_endtime(endtime):
    """Return the number of whole days from now until ``endtime``.

    @param endtime  date text in the ``%Y-%m-%d`` layout, read as local time
    @return int number of days, truncated toward zero; negative when the
            date is already in the past
    @raise ValueError when ``endtime`` does not match the expected layout
    """
    # Imported locally (as strfToTime does) so the function does not rely on
    # a module-level ``import time``.
    import time

    expires_at = time.mktime(time.strptime(endtime, "%Y-%m-%d"))
    return int((expires_at - time.time()) / 86400)


# Get the certificate hash and expiry date
def get_cert_info(cret_data):
    """Return the SHA-1 digest and expiry of a PEM certificate.

    @param cret_data  PEM text passed to ``crypto.load_certificate``
    @return dict with ``hash`` (the SHA-1 digest with the colons removed),
            ``timeout`` (the notAfter date as ``%Y-%m-%d``) and
            ``unix_timeout`` (``get_unixtime`` applied to that date)
    """
    from OpenSSL import crypto

    certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cret_data)

    fingerprint = bytes.decode(certificate.digest('sha1')).replace(':', '')
    # [:-1] drops the last character of the timestamp before parsing.
    expires_on = strfToTime(bytes.decode(certificate.get_notAfter())[:-1])

    return {
        'hash': fingerprint,
        'timeout': expires_on,
        'unix_timeout': get_unixtime(expires_on, "%Y-%m-%d"),
    }

      
def get_root_domain(domain_name):
    '''
        @name Split a domain name into its root domain and sub-domain prefix
        @author cjxin<2020-12-17>
        @param domain_name {string} domain name to split
        @return tuple (root, zone): root is the last two labels, or the last
                three when the name ends in one of the two-level suffixes
                listed below; zone is everything in front of root, or ""
                when there is nothing in front
    '''
    two_level_suffixes = (
        '.ac.cn', '.ah.cn', '.bj.cn', '.com.cn', '.cq.cn', '.fj.cn', '.gd.cn',
        '.gov.cn', '.gs.cn', '.gx.cn', '.gz.cn', '.ha.cn', '.hb.cn', '.he.cn',
        '.hi.cn', '.hk.cn', '.hl.cn', '.hn.cn', '.jl.cn', '.js.cn', '.jx.cn',
        '.ln.cn', '.mo.cn', '.net.cn', '.nm.cn', '.nx.cn', '.org.cn', '.cn.com', '.edu.cn',
    )

    suffix = "." + ".".join(domain_name.rsplit('.')[-2:])
    has_two_level_suffix = suffix in two_level_suffixes

    if has_two_level_suffix:
        # Collapse e.g. ".com.cn" into ".comcn" so it counts as a single label.
        collapsed = domain_name[:-len(suffix)] + "." + suffix.replace(".", "")
    else:
        collapsed = domain_name

    # At most one dot left: the name is already a root domain.
    if collapsed.count(".") <= 1:
        return domain_name, ""

    zone, middle, last = collapsed.rsplit(".", 2)
    if has_two_level_suffix:
        # Restore the real suffix in place of the collapsed label.
        last = suffix[1:]
    return ".".join([middle, last]), zone

def query_dns(domain, dns_type='A', is_root=False):
    """Query DNS for ``domain`` and return the records of type ``dns_type``.

    @param domain    name to look up
    @param dns_type  record type; ``A``, ``CNAME``, ``NS`` and ``CAA`` records
                     are decoded, any other type yields one empty dict per
                     record
    @param is_root   when true, look up the root domain computed by
                     ``get_root_domain`` instead of ``domain`` itself
    @return list of dicts: ``{'value': ...}`` for A/CNAME/NS records and
            ``{'flags': ..., 'tag': ..., 'value': ...}`` for CAA records;
            ``False`` when the lookup or the decoding fails
    """
    import os

    try:
        import dns.resolver
    except Exception:
        # dnspython could not be imported: try installing it, then retry.
        # NOTE(review): get_run_pip is called unqualified here, while
        # get_cert_data calls public.get_run_pip - confirm which is intended.
        os.system(get_run_pip('[PIP] install dnspython'))
        import dns.resolver

    if is_root:
        domain, _zone = get_root_domain(domain)

    try:
        ret = dns.resolver.query(domain, dns_type)
        data = []
        for rrset in ret.response.answer:
            # The answer section can also carry the CNAME records that were
            # followed to reach the result. Decoding those as the requested
            # type (e.g. reading ``.address``) would raise and turn the whole
            # lookup into False, so only rrsets of the queried type are read.
            if rrset.rdtype != ret.rdtype:
                continue
            for record in rrset:
                tmp = {}
                if dns_type == 'A':
                    tmp['value'] = record.address
                elif dns_type in ('CNAME', 'NS'):
                    tmp['value'] = record.to_text()
                elif dns_type == 'CAA':
                    tmp['flags'] = record.flags
                    tmp['tag'] = record.tag.decode()
                    tmp['value'] = record.value.decode()
                data.append(tmp)
        return data
    except Exception:
        return False