/**
 * Copyright (c) 2025, Fabian Groffen. All rights reserved.
 *
 * See LICENSE for the license.
 */

#include <stdio.h>
#include <unistd.h>
#include <strings.h>
#include <stdbool.h>
#include <stdlib.h>
#include <limits.h>

#include <ldns/ldns.h>

#include "util.h"

static int
host_print_record
(
    ldns_rr_list *ans
)
{
    size_t i;
    size_t j;
    int    ret = EXIT_FAILURE;

    for (i = 0; i < ldns_rr_list_rr_count(ans); i++)
    {
        ldns_rr      *rr   = ldns_rr_list_rr(ans, i);
        ldns_rdf     *base = ldns_rr_owner(rr);
        ldns_rr_type  tpe  = ldns_rr_get_type(rr);
        char         *own;
        char         *trg;
        uint32_t      prio;

        /* drop trailing dot */
        ldns_rdf_set_size(base, ldns_rdf_size(base) - 1);
        own = ldns_rdf2str(base);

        if (tpe == LDNS_RR_TYPE_A ||
            tpe == LDNS_RR_TYPE_AAAA)
        {
            ldns_rdf *targ = ldns_rr_rdf(rr, 0);
            char     *dest = ldns_rdf2str(targ);

            fprintf(stdout, "%s has %saddress %s\n",
                    own, tpe == LDNS_RR_TYPE_AAAA ? "IPv6 " : "", dest);
            ret = EXIT_SUCCESS;

            free(dest);
        }
        else if (tpe == LDNS_RR_TYPE_CNAME)
        {
            ldns_rdf *targ = ldns_rr_rdf(rr, 0);
            char     *dest = ldns_rdf2str(targ);

            fprintf(stdout, "%s is an alias for %s\n", own, dest);
            ret = EXIT_SUCCESS;

            free(dest);
        }
        else if (tpe == LDNS_RR_TYPE_MX)
        {
            trg = NULL;
            for (j = 0; j < ldns_rr_rd_count(rr); j++)
            {
                ldns_rdf      *targ = ldns_rr_rdf(rr, j);
                ldns_rdf_type  t    = ldns_rdf_get_type(targ);

                switch (t)
                {
                    case LDNS_RDF_TYPE_DNAME:
                        trg = ldns_rdf2str(targ);
                        break;
                    case LDNS_RDF_TYPE_INT8:
                        prio = (uint32_t)ldns_rdf2native_int8(targ);
                        break;
                    case LDNS_RDF_TYPE_INT16:
                        prio = (uint32_t)ldns_rdf2native_int16(targ);
                        break;
                    case LDNS_RDF_TYPE_INT32:
                        prio = ldns_rdf2native_int32(targ);
                        break;
                    default:
                        /* ignore this RD */
                        break;
                }
            }
            if (trg != NULL)
            {
                fprintf(stdout, "%s mail is handled by %u %s\n",
                        own, prio, trg);
                ret = EXIT_SUCCESS;
            }

            free(trg);
        }
        else if (tpe == LDNS_RR_TYPE_PTR)
        {
            ldns_rdf *targ = ldns_rr_rdf(rr, 0);
            char     *dest = ldns_rdf2str(targ);

            fprintf(stdout, "%s domain name pointer %s\n", own, dest);
            ret = EXIT_SUCCESS;

            free(dest);
        }
        else if (tpe == LDNS_RR_TYPE_TXT)
        {
            ldns_rdf *targ = ldns_rr_rdf(rr, 0);
            char     *dest = ldns_rdf2str(targ);

            fprintf(stdout, "%s descriptive text %s\n", own, dest);
            ret = EXIT_SUCCESS;

            free(dest);
        }
        else if (tpe == LDNS_RR_TYPE_NS)
        {
            ldns_rdf *targ = ldns_rr_rdf(rr, 0);
            char     *dest = ldns_rdf2str(targ);

            fprintf(stdout, "%s name server %s\n", own, dest);
            ret = EXIT_SUCCESS;

            free(dest);
        }
        else if (tpe == LDNS_RR_TYPE_SOA)
        {
            ldns_rdf *targ = ldns_rr_rdf(rr, 0);
            char     *dest = ldns_rdf2str(targ);

            fprintf(stdout, "%s has SOA record %s", own, dest);
            free(dest);
            for (j = 1; j < ldns_rr_rd_count(rr); j++)
            {
                ldns_rdf_type  t;

                targ = ldns_rr_rdf(rr, j);
                t    = ldns_rdf_get_type(targ);

                switch (t)
                {
                    case LDNS_RDF_TYPE_DNAME:
                    case LDNS_RDF_TYPE_STR:
                        dest = ldns_rdf2str(targ);
                        fprintf(stdout, " %s", dest);
                        free(dest);
                        break;
                    case LDNS_RDF_TYPE_INT8:
                        fprintf(stdout, " %u",
                                (uint32_t)ldns_rdf2native_int8(targ));
                        break;
                    case LDNS_RDF_TYPE_INT16:
                        fprintf(stdout, " %u",
                                (uint32_t)ldns_rdf2native_int16(targ));
                        break;
                    case LDNS_RDF_TYPE_INT32:
                    case LDNS_RDF_TYPE_PERIOD:
                        fprintf(stdout, " %u",
                                ldns_rdf2native_int32(targ));
                        break;
                    default:
                        /* ignore this RD */
                        break;
                }
            }
            fprintf(stdout, "\n");

            ret = EXIT_SUCCESS;
        }

        free(own);
    }

    return ret;
}

/**
 * Look up domain once for every type in the (ldns_rr_type)0-terminated
 * types list, using query class qclass.
 *
 * Each answer is printed host-style through host_print_record(), or as a
 * full packet dump when verbose is set.  A missing reply, or a reply
 * that is neither an answer nor NODATA, stops the loop with a message on
 * stderr.  When implicit is false and nothing was printed, a
 * "has no <type> record" message naming the last queried type is
 * written to stderr.
 *
 * Returns EXIT_SUCCESS or EXIT_FAILURE.  The error paths strip the
 * trailing dot from domain in place.
 */
static int
host_func
(
    ldns_resolver *res,
    ldns_rdf      *domain,
    ldns_rr_type  *types,
    ldns_rr_class  qclass,
    bool           verbose,
    bool           implicit
)
{
    ldns_rr_type *cur;
    ldns_rr_type  queried = (ldns_rr_type)0;
    bool          printed = false;

    for (cur = types; (uint32_t)*cur != 0; cur++)
    {
        ldns_rr_type  qtype = *cur;
        ldns_pkt     *reply = ldns_resolver_search(res, domain, qtype,
                                                   qclass, LDNS_RD);
        bool          usable;

        if (reply == NULL)
        {
            fprintf(stderr,
                    ";; connection timed out; no servers could be reached\n");
            return EXIT_FAILURE;
        }

        usable = ldns_pkt_reply_type(reply) == LDNS_PACKET_ANSWER ||
                 ldns_pkt_reply_type(reply) == LDNS_PACKET_NODATA;
        if (!usable)
        {
            ldns_pkt_rcode  code = ldns_pkt_get_rcode(reply);
            char           *name;
            char           *codestr;

            if (ldns_dname_absolute(domain))
                ldns_rdf_set_size(domain, ldns_rdf_size(domain) - 1);
            name    = ldns_rdf2str(domain);
            codestr = ldns_pkt_rcode2str(code);

            fprintf(stderr, "Host %s not found: %u(%s)\n",
                    name, (uint32_t)code, codestr);

            free(name);
            free(codestr);
            ldns_pkt_free(reply);

            return EXIT_FAILURE;
        }

        if (!verbose)
        {
            ldns_rr_list *rrs = ldns_pkt_rr_list_by_type(reply, qtype,
                                                         LDNS_SECTION_ANSWER);

            /* mail exchangers are listed in sorted order */
            if (qtype == LDNS_RR_TYPE_MX)
                ldns_rr_list_sort(rrs);
            if (host_print_record(rrs) == EXIT_SUCCESS)
                printed = true;
            ldns_rr_list_deep_free(rrs);
        }
        else
        {
            char *dump = ldns_pkt2str(reply);

            fprintf(stdout, "%s\n", dump);
            free(dump);
            printed = true;
        }

        ldns_pkt_free(reply);
        queried = qtype;
    }

    if (implicit || printed || queried == (ldns_rr_type)0)
        return EXIT_SUCCESS;

    {
        char *name;
        char *typestr;

        if (ldns_dname_absolute(domain))
            ldns_rdf_set_size(domain, ldns_rdf_size(domain) - 1);
        name = ldns_rdf2str(domain);

        /* only the last queried type is named in the message */
        typestr = ldns_rr_type2str(queried);

        fprintf(stderr, "%s has no %s record\n", name, typestr);

        free(name);
        free(typestr);
    }

    return EXIT_FAILURE;
}

/**
 * Transfer the zone domain (class qclass) via AXFR and print every
 * record whose type appears in the (ldns_rr_type)0-terminated types
 * array: host-style via host_print_record(), or one record per line in
 * ldns presentation format when verbose is set.
 *
 * Returns EXIT_FAILURE when the transfer cannot be started or is aborted
 * (a diagnostic is written to stderr), EXIT_SUCCESS otherwise.
 */
static int
host_afxr
(
    ldns_resolver *res,
    ldns_rdf      *domain,
    ldns_rr_type  *types,
    ldns_rr_class  qclass,
    bool           verbose
)
{
    ldns_rr_list *ans;
    ldns_status   s;

    ans = ldns_rr_list_new();
    if (ans == NULL)
        return EXIT_FAILURE;

    s = ldns_axfr_start(res, domain, qclass);
    if (s != LDNS_STATUS_OK)
    {
        char       *dom = ldns_rdf2str(domain);
        const char *why = ldns_get_errorstr_by_id(s);

        fprintf(stderr, "host: transfer of %s failed: %s\n",
                dom != NULL ? dom : "?",
                why != NULL ? why : "unknown error");
        free(dom);
        ldns_rr_list_free(ans);  /* still empty */

        return EXIT_FAILURE;
    }

    while (!ldns_axfr_complete(res))
    {
        ldns_rr      *rr = ldns_axfr_next(res);
        ldns_rr_type  tpe;
        ldns_rr_type *tpew;
        bool          wanted = false;

        /* check for NULL before touching rr in any way */
        if (rr == NULL)
        {
            char           *dom;
            char           *rcd;
            ldns_pkt       *pkt;
            ldns_pkt_rcode  rc = LDNS_RCODE_NOTAUTH;

            /* NOTE: ldns_axfr_last_pkt() hands out the resolver's own
             * packet, not a copy; it is released together with the
             * resolver, so it must not be freed here (double free) */
            pkt = ldns_axfr_last_pkt(res);
            if (pkt != NULL)
                rc = ldns_pkt_get_rcode(pkt);

            /* drop trailing dot, like host_func() does in its message */
            if (ldns_rdf_size(domain) > 1 && ldns_dname_absolute(domain))
                ldns_rdf_set_size(domain, ldns_rdf_size(domain) - 1);
            dom = ldns_rdf2str(domain);
            rcd = ldns_pkt_rcode2str(rc);

            fprintf(stderr, "Host %s not found: %u(%s)\n",
                    dom, (uint32_t)rc, rcd);

            free(dom);
            free(rcd);
            ldns_rr_list_deep_free(ans);  /* the list owns its RRs */

            return EXIT_FAILURE;
        }

        tpe = ldns_rr_get_type(rr);
        for (tpew = types; (uint32_t)*tpew != 0; tpew++)
        {
            if (*tpew == tpe)
            {
                wanted = true;
                break;
            }
        }

        /* rr is ours until the list successfully takes it over */
        if (!wanted || !ldns_rr_list_push_rr(ans, rr))
            ldns_rr_free(rr);
    }

    if (verbose)
    {
        size_t i;

        for (i = 0; i < ldns_rr_list_rr_count(ans); i++)
        {
            char *digd = ldns_rr2str(ldns_rr_list_rr(ans, i));

            /* ldns_rr2str() output already ends with a newline */
            if (digd != NULL)
                fputs(digd, stdout);
            free(digd);
        }
    }
    else
    {
        host_print_record(ans);
    }

    ldns_rr_list_deep_free(ans);

    return EXIT_SUCCESS;
}

/**
 * Write the command-line synopsis and option summary to dst.
 *
 * The synopsis lists exactly the options accepted by main()'s getopt()
 * string "46aAc:dilN:p:srR:t:TUvVwW:" (no -C or -m, which are not
 * implemented).
 */
static void
do_usage
(
    FILE *dst
)
{
    static const char usage[] =
"Usage: host [-46aAdilrsTUvVw] [-c class] [-N ndots] [-t type] [-W time]\n"
"            [-R number] [-p port] hostname [server]\n"
"    -a is equivalent to -v -t ANY\n"
"    -A is like -a but omits RRSIG, NSEC, NSEC3\n"
"    -c specifies query class for non-IN data\n"
"    -d is equivalent to -v\n"
"    -i reverse lookups of IPv6 addresses use ip6.int domain\n"
"    -l lists all hosts in a domain, using AXFR\n"
"    -N changes the number of dots allowed before root lookup is done\n"
"    -p specifies the port on the server to query\n"
"    -r disables recursive processing\n"
"    -R specifies number of retries for UDP packets\n"
"    -s a SERVFAIL response should stop query\n"
"    -t specifies the query type\n"
"    -T enables TCP/IP mode\n"
"    -U enables UDP mode\n"
"    -v enables verbose output\n"
"    -V print version number and exit\n"
"    -w specifies to wait forever for a reply\n"
"    -W specifies how long to wait for a reply\n"
"    -4 use IPv4 query transport only\n"
"    -6 use IPv6 query transport only\n";

    /* plain text, no format directives: fputs() is sufficient */
    fputs(usage, dst);
}

int
main(int argc, char **argv)
{
    ldns_resolver *res;
    ldns_rdf      *domain;
    ldns_rr_type  *qtypes   = NULL;
    ldns_status    s;
    ldns_rr_class  qclass   = LDNS_RR_CLASS_IN;
    int            opt;
    char          *typestr  = NULL;
    char          *classstr = NULL;
    uint8_t        ndots    = 1;
    uint8_t        nretries = 1;
    uint8_t        ipmode   = 0;  /* no preference */
    uint8_t        timeout  = 1;  /* second */
    uint8_t        tcpudp   = 0;  /* no preference */
    uint16_t       port     = 0;
    bool           verbose  = false;
    bool           doafxr   = false;
    bool           doip6int = false;  /* stick with ip6.arpa */
    bool           recurse  = true;
    bool           failfast = false;  /* try all nameservers */

    while ((opt = getopt(argc, argv, "46aAc:dilN:p:srR:t:TUvVwW:")) != -1)
    {
        switch (opt) {
            case '4':
                ipmode = 1;
                break;
            case '6':
                ipmode = 2;
                break;
            case 'a':
            case 'A':  /* should be a - [RRSIG, NSEC, NSEC3] */
                typestr = "any";
                /* fall through */
            case 'd':
            case 'v':
                verbose = true;
                break;
            case 'c':
                classstr = optarg;
                break;
            case 'i':
                doip6int = true;
                break;
            case 'l':
                doafxr = true;
                break;
            case 'N':
                {
                    char *retp;
                    long  val;

                    val = strtol(optarg, &retp, 10);
                    if (retp == optarg ||
                        *retp != '\0' ||
                        val < 0)
                    {
                        fprintf(stderr,
                                "host: invalid argument for -N: %s\n",
                                optarg);
                        exit(EXIT_FAILURE);
                    }
                    if (val > UCHAR_MAX)
                    {
                        fprintf(stderr,
                                "host: value for -N too large: %ld\n",
                                val);
                        exit(EXIT_FAILURE);
                    }

                    ndots = (uint8_t)val;
                }
                break;
            case 'p':
                {
                    char *retp;
                    long  val;

                    val = strtol(optarg, &retp, 10);
                    if (retp == optarg ||
                        *retp != '\0' ||
                        val <= 0)
                    {
                        fprintf(stderr,
                                "host: invalid argument for -p: %s\n",
                                optarg);
                        exit(EXIT_FAILURE);
                    }
                    if (val > USHRT_MAX)
                    {
                        fprintf(stderr,
                                "host: value for -p too large: %ld\n",
                                val);
                        exit(EXIT_FAILURE);
                    }

                    port = (uint16_t)val;
                }
                break;
            case 'r':
                recurse = false;
                break;
            case 'R':
                {
                    char *retp;
                    long  val;

                    val = strtol(optarg, &retp, 10);
                    if (retp == optarg ||
                        *retp != '\0')
                    {
                        fprintf(stderr,
                                "host: invalid argument for -R: %s\n",
                                optarg);
                        exit(EXIT_FAILURE);
                    }
                    if (val > UCHAR_MAX)
                    {
                        fprintf(stderr,
                                "host: value for -R too large: %ld\n",
                                val);
                        exit(EXIT_FAILURE);
                    }

                    if (val <= 0)
                        nretries = 1;
                    else
                        nretries = (uint8_t)val;
                }
                break;
            case 's':
                failfast = true;
                break;
            case 't':
                typestr = optarg;
                break;
            case 'U':
                tcpudp = 2;
                break;
            case 'T':
                tcpudp = 1;
                break;
            case 'V':
                fprintf(stdout, "host %s from ldns-tools (using ldns %s)\n",
                        LDNS_TOOLS_VERSION, ldns_version());
                exit(EXIT_SUCCESS);
            case 'w':
                timeout = 0;  /* forever */
                break;
            case 'W':
                {
                    char *retp;
                    long  val;

                    val = strtol(optarg, &retp, 10);
                    if (retp == optarg ||
                        *retp != '\0')
                    {
                        fprintf(stderr,
                                "host: invalid argument for -W: %s\n",
                                optarg);
                        exit(EXIT_FAILURE);
                    }
                    if (val > UCHAR_MAX)
                    {
                        fprintf(stderr,
                                "host: value for -W too large: %ld\n",
                                val);
                        exit(EXIT_FAILURE);
                    }

                    if (val <= 0)
                        timeout = 1;
                    else
                        timeout = (uint8_t)val;
                }
                break;
            default:
                do_usage(stderr);
                exit(EXIT_FAILURE);
        }
    }

    if (optind >= argc)
    {
        do_usage(stdout);
        exit(EXIT_FAILURE);
    }

    if (typestr != NULL)
    {
        ldns_rr_type rrtype = ldns_get_rr_type_by_name(typestr);

        switch (rrtype)
        {
            case (ldns_rr_type)0:
                fprintf(stderr, "host: invalid type: %s\n", typestr);
                exit(EXIT_FAILURE);
            case LDNS_RR_TYPE_A:
            case LDNS_RR_TYPE_AAAA:
            case LDNS_RR_TYPE_TXT:
                qtypes = calloc(sizeof(qtypes[0]), 3);
                qtypes[0] = LDNS_RR_TYPE_CNAME;
                qtypes[1] = rrtype;
                break;
            case LDNS_RR_TYPE_ANY:
                qtypes = calloc(sizeof(qtypes[0]), 8);
                qtypes[0] = LDNS_RR_TYPE_SOA;
                qtypes[1] = LDNS_RR_TYPE_NS;
                qtypes[2] = LDNS_RR_TYPE_CNAME;
                qtypes[3] = LDNS_RR_TYPE_A;
                qtypes[4] = LDNS_RR_TYPE_AAAA;
                qtypes[5] = LDNS_RR_TYPE_MX;
                qtypes[6] = LDNS_RR_TYPE_TXT;
                break;
            default:
                qtypes = calloc(sizeof(qtypes[0]), 2);
                qtypes[0] = rrtype;
                break;
        }
    }

    if (classstr != NULL)
    {
        if (strcasecmp(classstr, "in") == 0)
            qclass = LDNS_RR_CLASS_IN;
        else if (strcasecmp(classstr, "ch") == 0)
            qclass = LDNS_RR_CLASS_CH;
        else if (strcasecmp(classstr, "hs") == 0)
            qclass = LDNS_RR_CLASS_HS;
        else
        {
            fprintf(stderr, "host: invalid class: %s\n", classstr);
            exit(EXIT_FAILURE);
        }
    }

    /* create resolver from /etc/resolv.conf */
    s = ldns_resolver_new_frm_file(&res, NULL);

    /* take second argument as DNS server to use */
    if (argc - optind > 1)
    {
        ldns_rdf **dns = util_addr_frm_str(res, argv[optind + 1], ndots);

        if (dns == NULL)
        {
            if (s == LDNS_STATUS_OK)
            {
                fprintf(stderr,
                        "host: couldn't get address for '%s': not found\n",
                        argv[optind + 1]);
                exit(EXIT_FAILURE);
            }
            /* else handled below */
        }
        else
        {
            ldns_rdf *swalk;
            size_t    i;

            if (s == LDNS_STATUS_OK)
                ldns_resolver_deep_free(res);

            res = ldns_resolver_new();
            for (i = 0; (swalk = dns[i]) != NULL; i++)
            {
                ldns_resolver_push_nameserver(res, swalk);
                ldns_rdf_deep_free(swalk);
            }
            free(dns);

            /* in case /etc/resolv.conf could not be read and an IP was
             * given, ensure the check below won't fire */
            s = LDNS_STATUS_OK;
        }
    }

    if (s != LDNS_STATUS_OK)
    {
        fprintf(stderr, "host: could not open/parse /etc/resolv.conf\n");
        exit(EXIT_FAILURE);
    }

    if (port > 0)
        ldns_resolver_set_port(res, port);
    ldns_resolver_set_recursive(res, recurse);
    ldns_resolver_set_retry(res, nretries);
    ldns_resolver_set_ip6(res, ipmode);
    {
        struct timeval tvtimeout = {
            .tv_sec = timeout == 0 ? INT_MAX : timeout
        };
        ldns_resolver_set_timeout(res, tvtimeout);
        ldns_resolver_set_retrans(res, timeout);
    }
    if (tcpudp > 0)
        ldns_resolver_set_usevc(res, tcpudp == 1 ? true : false);
    ldns_resolver_set_fail(res, failfast);

    s = ldns_str2rdf_a(&domain, argv[optind]);
    if (s != LDNS_STATUS_OK)
        s = ldns_str2rdf_aaaa(&domain, argv[optind]);
    if (s == LDNS_STATUS_OK)
    {
        if (!doafxr &&
            qtypes == NULL)
        {
            qtypes = calloc(sizeof(qtypes[0]), 2);
            qtypes[0] = LDNS_RR_TYPE_PTR;
        }

        domain = util_addr2dname(domain, doip6int);
    }
    else
    {
        domain = util_dname_frm_str(argv[optind], ndots, doip6int);
    }

    if (verbose)
    {
        char *dom = ldns_rdf2str(domain);
        fprintf(stdout, "Trying \"%s\"\n", dom);
        free(dom);
    }
    if (doafxr)
    {
        if (qtypes == NULL)
        {
            qtypes = calloc(sizeof(qtypes[0]), 5);
            qtypes[0] = LDNS_RR_TYPE_NS;
            qtypes[1] = LDNS_RR_TYPE_A;
            qtypes[2] = LDNS_RR_TYPE_AAAA;
            qtypes[3] = LDNS_RR_TYPE_PTR;
        }
        host_afxr(res, domain, qtypes, qclass, verbose);
    }
    else
    {
        bool implicit = false;

        if (qtypes == NULL)
        {
            qtypes = calloc(sizeof(qtypes[0]), 5);
            qtypes[0] = LDNS_RR_TYPE_CNAME;
            qtypes[1] = LDNS_RR_TYPE_A;
            qtypes[2] = LDNS_RR_TYPE_AAAA;
            qtypes[3] = LDNS_RR_TYPE_MX;
            implicit  = true;
        }
        host_func(res, domain, qtypes, qclass, verbose, implicit);
    }

    free(qtypes);
    ldns_rdf_deep_free(domain);
    ldns_resolver_deep_free(res);

    return EXIT_SUCCESS;
}

/* vim: set ts=4 sw=4 expandtab cinoptions=(0,u0,U1,W2s,l1: */
