/*
  This file is a part of Qosmos ixEngine.

   Copyright  Qosmos 2022 - All rights reserved

  This computer program and all its components are protected by
  authors' rights and copyright law and by international treaties.
  Any representation, reproduction, distribution or modification
  of this program or any portion of it is forbidden without
  Qosmos explicit and written agreement and may result in severe
  civil and criminal penalties, and will be prosecuted
  to the maximum extent possible under the law.
*/

#include <vppinfra/dlist.h>
#include <vppinfra/error.h>
#include <vppinfra/types.h>
#include <vppinfra/vec.h>
#include <vnet/ip/ip4_packet.h>
#include <vnet/tcp/tcp_packet.h>
#include <vnet/udp/udp_packet.h>
#include <vnet/pg/pg.h>
#include <vnet/plugin/plugin.h>
#include <vnet/vnet.h>
#include <vlib/vlib.h>
#include <dpi/dpi.h>
#include <qmdpi.h>
#include <dpi/protodef.h>
#include <flowtable/flowdata.h>
#include <flowtable/flowtable.h>
#include <arpa/inet.h>

/*
 * DBG_PRINT: printf-style debug output.
 * Forwards its arguments to printf() when DEBUG is defined and >= 1;
 * otherwise it expands to nothing, so its arguments are not evaluated.
 */
#if (defined DEBUG) && (DEBUG >= 1)
#define DBG_PRINT(args ...) printf(args)
#else
#define DBG_PRINT(args ...)
#endif

extern dpi_main_t dpi_main;

/*
 * Per-packet trace record: filled by the dpi node for traced buffers and
 * rendered by format_dpi_trace().
 */
typedef struct {
    u32 next_index;      /* next node index the packet is sent to */
    u32 sw_if_index;     /* RX software interface index of the packet */
    u8  new_src_mac[6];  /* source MAC copied from the Ethernet header */
    u8  new_dst_mac[6];  /* destination MAC copied from the Ethernet header */
} dpi_trace_t;

#ifndef CLIB_MARCH_VARIANT
/*
 * format() helper intended for "%U": appends a 6-byte MAC address to s as
 * colon-separated, zero-padded lowercase hex (xx:xx:xx:xx:xx:xx).
 *
 * Consumes a single u8 * argument (the address bytes) from args and
 * returns the value produced by format().
 */
static u8 *my_format_mac_address(u8 *s, va_list *args)
{
    const u8 *mac = va_arg(*args, u8 *);

    s = format(s, "%02x:%02x:%02x:%02x:%02x:%02x",
               mac[0], mac[1], mac[2], mac[3], mac[4], mac[5]);
    return s;
}

/*
 * Packet trace formatter of the dpi node.
 *
 * Consumes, in order, a vlib_main_t *, a vlib_node_t * and a dpi_trace_t *
 * from args, then appends the interface index, the next index and both MAC
 * addresses of the trace record to s.
 */
static u8 *format_dpi_trace(u8 *s, va_list *args)
{
    /* The first two arguments are not used here, but they must still be
       pulled off the list to reach the trace record. */
    (void) va_arg(*args, vlib_main_t *);
    (void) va_arg(*args, vlib_node_t *);
    dpi_trace_t *trace = va_arg(*args, dpi_trace_t *);

    s = format(s, "DPI: sw_if_index %d, next index %d\n",
               trace->sw_if_index, trace->next_index);
    return format(s, "  new src %U -> new dst %U",
                  my_format_mac_address, trace->new_src_mac,
                  my_format_mac_address, trace->new_dst_mac);
}

vlib_node_registration_t dpi_node;

#endif /* CLIB_MARCH_VARIANT */

/*
 * X-macro listing the dpi node counters as _(symbol, description) pairs.
 * It is expanded below to build both the dpi_error_t enum and the
 * dpi_error_strings table, which keeps the two in the same order.
 */
#define foreach_dpi_error \
    _(PACKETS, "Packets submitted for classification") \
    _(UNHANDLED, "Unhandled (non-ip) packets") \
    _(FAILED, "DPI handling failed") \
    _(CLASSIFIED, "Flows classified") \
    _(VOLUME, "Volume of traffic (in bytes)")

/*
 * Counter identifiers: one DPI_ERROR_<sym> per foreach_dpi_error entry,
 * followed by DPI_N_ERROR, which therefore equals the number of counters.
 */
typedef enum {
#define _(sym,str) DPI_ERROR_##sym,
    foreach_dpi_error
#undef _
    DPI_N_ERROR,
} dpi_error_t;

#ifndef CLIB_MARCH_VARIANT
/*
 * Counter descriptions, in foreach_dpi_error order (so indexable by
 * dpi_error_t). Registered as the error strings of the dpi node.
 */
static char *dpi_error_strings[] = {
#define _(sym,string) string,
    foreach_dpi_error
#undef _
};
#endif /* CLIB_MARCH_VARIANT */

/*
 * Add _val to the dpi node counter DPI_ERROR_<_sym>; nothing is done when
 * _val is zero.
 *
 * The body is wrapped in do { } while (0) so the macro behaves as one
 * statement: a bare `if` would capture the `else` of an enclosing unbraced
 * if/else at the call site.
 *
 * _val is expanded twice, so pass an expression without side effects.
 */
#define DPI_COUNTER_INC(_vm, _sym, _val) \
    do { \
        if (_val) \
            vlib_node_increment_counter(_vm, dpi_node.index, \
                                        DPI_ERROR_ ## _sym, _val); \
    } while (0)

/*
 * Report a flow whose classification has just changed.
 *
 * When the result flags say the flow is in classified state AND that this
 * state changed, the classification path is rendered into a text buffer,
 * printed through dpi_print() and the CLASSIFIED counter is incremented by
 * one. In every other case the function does nothing.
 *
 * vm     - vlib main, used for printing and for the counter
 * result - DPI result to inspect
 * worker, df, sig - not used by this function
 */
static void classification_result_manage(vlib_main_t         *vm,
                                         struct qmdpi_worker *worker,
                                         struct qmdpi_flow   *df,
                                         struct qmdpi_result *result,
                                         flow_signature_t    *sig)
{
    dpi_main_t                      *dpi = &dpi_main;
    struct qmdpi_result_flags const *flags = qmdpi_result_flags_get(result);

    /* Only a fresh transition to the classified state is of interest */
    if (!(QMDPI_RESULT_FLAGS_CLASSIFIED_STATE(flags))
            || !(QMDPI_RESULT_FLAGS_CLASSIFIED_STATE_CHANGED(flags))) {
        return;
    }

    struct qmdpi_bundle *bundle =
        qmdpi_flow_bundle_get(qmdpi_result_flow_get(result));
    struct qmdpi_path   *path = qmdpi_result_path_get(result);

    /* Render the classification path as text and log it */
    char text[200];
    qmdpi_data_path_to_buffer(bundle, text, sizeof(text), path);
    dpi_print(vm, "[DPI] flow classified as %s\n", text);

    DPI_COUNTER_INC(vm, CLASSIFIED, 1);
}

/*
 * Submit one PDU to the DPI worker and run the classification.
 *
 * The PDU (data/length/ts/dir) and its L3/L4 header pointers are first
 * registered with the worker; if either step fails the FAILED counter is
 * incremented and the function returns. The worker is then run repeatedly
 * for as long as it returns QMDPI_PROCESS_MORE. A negative return
 * increments FAILED. In DEBUG builds (DEBUG >= 1) with dpi->log_enable
 * set, each non-failing result is passed to classification_result_manage().
 *
 * vm     - vlib main, used for counters
 * worker - DPI worker handling the packet
 * df     - DPI flow the packet belongs to
 * data   - start of the PDU handed to the worker
 * length - number of bytes available at data
 * ts     - packet timestamp
 * l3hdr  - L3 header pointer given to the worker
 * l4hdr  - L4 header pointer given to the worker
 * dir    - packet direction, as expected by qmdpi_worker_pdu_set()
 * sig    - flow signature, only forwarded to the result handler
 */
static void classification_process(vlib_main_t          *vm,
                                   struct qmdpi_worker  *worker,
                                   struct qmdpi_flow    *df,
                                   const u_char         *data,
                                   u32                   length,
                                   const struct timeval *ts,
                                   const void           *l3hdr,
                                   const void           *l4hdr,
                                   int                   dir,
                                   flow_signature_t     *sig)
{
    dpi_main_t          *dpi = &dpi_main;
    struct qmdpi_result *result;
    int                  status;

    /* Describe the PDU, then its L3/L4 headers; the header step is only
       attempted when the PDU step succeeded. */
    if (qmdpi_worker_pdu_set(worker, data, length, ts, QMDPI_PROTO_ETH, dir,
                             0) != 0
            || qmdpi_worker_pdu_header_set(worker, (void *) l3hdr,
                                           (void *) l4hdr) != 0) {
        DPI_COUNTER_INC(vm, FAILED, 1);
        return;
    }

    /* Let the worker consume the PDU until it has nothing more to report */
    do {
        status = qmdpi_worker_process(worker, df, &result);
        if (status < 0) {
            DPI_COUNTER_INC(vm, FAILED, 1);
        }
#if (defined DEBUG) && (DEBUG >= 1)
        else if (dpi->log_enable) {
            /* Debug builds only: log classification changes */
            classification_result_manage(vm, worker, df, result, sig);
        }
#endif /* DEBUG */
    } while (status == QMDPI_PROCESS_MORE);
}

/*
 * Run DPI classification on one buffer, if it is eligible.
 *
 * The buffer is classified only when all of the following hold:
 *   - a DPI flow (df) is available,
 *   - the buffer holds more than an Ethernet header,
 *   - the flow signature is an IPv4 or an IPv6 one,
 *   - an L4 header pointer is present in the packet signature,
 *   - more than pkt_sig->data_offset bytes follow the Ethernet header.
 * Otherwise the function returns without doing anything.
 *
 * vm           - vlib main
 * worker       - DPI worker to use
 * df           - DPI flow of the packet
 * b            - buffer whose current data starts at the Ethernet header
 * pkt_sig      - packet signature (L4 header pointer and data offset)
 * sig          - flow signature
 * current_time - timestamp stored in tv_sec of the packet time
 * is_reverse   - non-zero selects QMDPI_DIR_STC, zero QMDPI_DIR_CTS
 */
static void process_dpi(vlib_main_t             *vm,
                        struct qmdpi_worker     *worker,
                        struct qmdpi_flow       *df,
                        vlib_buffer_t           *b,
                        packet_signature_t      *pkt_sig,
                        flow_signature_t        *sig,
                        u32                      current_time,
                        int                      is_reverse)
{
    const void *l4hdr = pkt_sig->l4hdr;

    if (!df || b->current_length <= sizeof(ethernet_header_t)) {
        return;
    }
    if (sig->len != sizeof(struct ip4_sig)
            && sig->len != sizeof(struct ip6_sig)) {
        return;
    }
    if (!l4hdr) {
        /* Only packets with a known L4 header are classified */
        return;
    }

    /* The L3 header sits right behind the Ethernet header */
    const u_char *l3hdr = (const u_char *) vlib_buffer_get_current(b)
                          + sizeof(ethernet_header_t);
    int remaining = b->current_length - sizeof(ethernet_header_t);

    if (remaining <= pkt_sig->data_offset) {
        return;
    }

    struct timeval ts = { .tv_sec = current_time, .tv_usec = 0 };

    /* Skip data_offset bytes from the L3 header to reach the payload */
    const u_char *payload = l3hdr + pkt_sig->data_offset;
    remaining -= pkt_sig->data_offset;

    int direction = is_reverse ? QMDPI_DIR_STC : QMDPI_DIR_CTS;

    classification_process(vm, worker, df, payload, remaining, &ts,
                           l3hdr, l4hdr, direction, sig);
}

/*
 * Rebuild the packet and flow signatures from the per-buffer metadata
 * stored (as a sig_info_t) in the opaque2 area of the buffer.
 *
 * b          - buffer whose current data starts at the Ethernet header
 * is_reverse - [out] direction flag copied from the metadata
 * pkt_sig    - [out] data offset and L4 header pointer; l3hdr is set to NULL
 * sig        - [out] IPv4 or IPv6 flow signature
 * df         - [out] DPI flow attached to the flow-table entry
 *
 * Returns 0 on success, -1 when the metadata references no flow-table entry
 * or when that entry has no DPI flow. On failure *is_reverse and *pkt_sig
 * have already been written, while *sig and *df are left untouched.
 */
static int unpack_sigs_from_opaque(vlib_buffer_t         *b,
                                   int                   *is_reverse,
                                   packet_signature_t    *pkt_sig,
                                   flow_signature_t      *sig,
                                   struct qmdpi_flow     **df)
{
    sig_info_t      *info = (sig_info_t *) vnet_buffer2(b)->unused;

    *is_reverse = info->is_reverse;

    pkt_sig->data_offset = info->data_offset;
    pkt_sig->l3hdr = NULL; /* not carried by the metadata */
    pkt_sig->l4hdr = info->l4hdr;

    /* Without a flow entry and its DPI flow there is nothing to classify;
       the entry pointer is checked before being dereferenced. */
    flow_entry_t *f = (flow_entry_t *) info->ft_flow;
    if (f == NULL || f->infos.data.dpi_flow == NULL) {
        return -1;
    }
    *df = f->infos.data.dpi_flow;

    if (info->is_ipv6) {
        sig->len          = sizeof(struct ip6_sig);
        sig->s.ip6.proto    = info->ip_protocol;
        sig->s.ip6.port_src = info->src_port;
        sig->s.ip6.port_dst = info->dst_port;

        /* IPv6 addresses are not stored in the metadata: only their offsets
           from the start of the IP header are, so they are copied (16 bytes
           each) from the packet itself. */
        u8 *ip6hdr = (u8 *) vlib_buffer_get_current(b);
        ip6hdr += sizeof(ethernet_header_t);
        memcpy(&sig->s.ip6.src,
               ip6hdr + info->ipv6.src_offset,
               16);
        memcpy(&sig->s.ip6.dst,
               ip6hdr + info->ipv6.dst_offset,
               16);
    } else {
        sig->len          = sizeof(struct ip4_sig);
        sig->s.ip4.proto    = info->ip_protocol;
        sig->s.ip4.src      = info->ipv4.src;
        sig->s.ip4.dst      = info->ipv4.dst;
        sig->s.ip4.port_src = info->src_port;
        sig->s.ip4.port_dst = info->dst_port;
    }

    return 0;
}

/*
 * Entry point for one buffer: classify it when it carries IPv4 or IPv6.
 *
 * The ethertype is read from the Ethernet header at the buffer's current
 * data. Non-IP frames only increment the UNHANDLED counter. For IP frames
 * the signatures are rebuilt from the buffer metadata and, when that
 * succeeds and yields a DPI flow, the packet is handed to process_dpi().
 *
 * vm           - vlib main
 * worker       - DPI worker to use
 * b            - buffer to inspect
 * current_time - timestamp forwarded to process_dpi()
 */
static void process_dpi_data(vlib_main_t             *vm,
                             struct qmdpi_worker     *worker,
                             vlib_buffer_t           *b,
                             u32                      current_time)
{
    packet_signature_t  pkt_sig;
    flow_signature_t    sig;
    struct qmdpi_flow   *df = NULL;
    int                 is_reverse;

    /* Only IP is handled: look at the ethertype first */
    const ethernet_header_t *eth = vlib_buffer_get_current(b);
    const int is_ip =
        (eth->type == clib_host_to_net_u16(ETHERNET_TYPE_IP6))
        || (eth->type == clib_host_to_net_u16(ETHERNET_TYPE_IP4));

    if (PREDICT_FALSE(!is_ip)) {
        DPI_COUNTER_INC(vm, UNHANDLED, 1);
        return;
    }

    if (unpack_sigs_from_opaque(b, &is_reverse, &pkt_sig, &sig, &df) < 0
            || !df) {
        return;
    }

    process_dpi(vm, worker, df, b, &pkt_sig, &sig, current_time, is_reverse);
}

/*
 * Node function of the "dpi" graph node.
 *
 * Every buffer of the frame is passed to process_dpi_data() with the DPI
 * worker of the current thread, then enqueued to dpi->next_node_index.
 * This function does not pick a different next node based on the DPI
 * outcome. Buffers are handled two at a time while at least four remain
 * (the following two are prefetched), then one at a time.
 *
 * The number of packets and the sum of their current_length are
 * accumulated locally and added to the PACKETS and VOLUME counters once
 * per frame.
 *
 * Returns the number of vectors of the incoming frame.
 */
VLIB_NODE_FN(dpi_node)(vlib_main_t         *vm,
                       vlib_node_runtime_t *node,
                       vlib_frame_t        *frame)
{
    u32                  n_left_from, *from, *to_next;
    dpi_next_t           next_index;
    dpi_main_t          *dpi = &dpi_main;
    u32                  CPT_PACKETS = 0;  /* packets seen in this frame */
    u32                  CPT_VOLUME = 0;   /* sum of current_length in this frame */

    from = vlib_frame_vector_args(frame);
    n_left_from = frame->n_vectors;
    next_index = node->cached_next_index;

    /* Retrieve CPU index and worker */
    /* NOTE(review): workers_table is indexed without a bounds check -
       presumably it holds one worker per thread; confirm at init. */
    u32 cpu_index = os_get_thread_index();
    struct qmdpi_worker *worker = dpi->workers_table[cpu_index];

    /* Retrieve current time */
    /* Last dispatch time divided by clocks_per_second; the result ends up in
       tv_sec of the packet timestamp (presumably whole seconds - confirm
       the unit of clocks_per_second). */
    u32 current_time = (u32)((u64)(vm->cpu_time_last_node_dispatch /
                                   dpi->clocks_per_second));

    while (n_left_from > 0) {
        u32 n_left_to_next;

        vlib_get_next_frame(vm, node, next_index,
                            to_next, n_left_to_next);

        /* Dual loop: needs four buffers left because from[2] and from[3]
           are prefetched while from[0] and from[1] are processed. */
        while (n_left_from >= 4 && n_left_to_next >= 2) {
            u32 next0 = dpi->next_node_index;
            u32 next1 = dpi->next_node_index;
            u32 sw_if_index0, sw_if_index1;
            ethernet_header_t *en0, *en1;
            u32 bi0, bi1;
            vlib_buffer_t *b0, *b1;

            /* Prefetch next iteration. */
            {
                vlib_buffer_t *p2, * p3;

                p2 = vlib_get_buffer(vm, from[2]);
                p3 = vlib_get_buffer(vm, from[3]);

                vlib_prefetch_buffer_header(p2, LOAD);
                vlib_prefetch_buffer_header(p3, LOAD);

                CLIB_PREFETCH(p2->data, CLIB_CACHE_LINE_BYTES, STORE);
                CLIB_PREFETCH(p3->data, CLIB_CACHE_LINE_BYTES, STORE);
            }

            /* speculatively enqueue b0 and b1 to the current next frame */
            to_next[0] = bi0 = from[0];
            to_next[1] = bi1 = from[1];
            from += 2;
            to_next += 2;
            n_left_from -= 2;
            n_left_to_next -= 2;

            b0 = vlib_get_buffer(vm, bi0);
            b1 = vlib_get_buffer(vm, bi1);

            /* Same expectation as in the single loop below: the buffers
               start at offset 0 of their data area. */
            ASSERT(b0->current_data == 0);
            ASSERT(b1->current_data == 0);

            process_dpi_data(vm, worker, b0, current_time);
            process_dpi_data(vm, worker, b1, current_time);
            CPT_PACKETS += 2;
            CPT_VOLUME += b0->current_length + b1->current_length;

            /* Ethernet header and RX interface are only needed for tracing */
            en0 = vlib_buffer_get_current(b0);
            en1 = vlib_buffer_get_current(b1);

            sw_if_index0 = vnet_buffer(b0)->sw_if_index[VLIB_RX];
            sw_if_index1 = vnet_buffer(b1)->sw_if_index[VLIB_RX];

            if (PREDICT_FALSE((node->flags & VLIB_NODE_FLAG_TRACE))) {
                if (b0->flags & VLIB_BUFFER_IS_TRACED) {
                    dpi_trace_t *t =
                        vlib_add_trace(vm, node, b0, sizeof(*t));
                    t->sw_if_index = sw_if_index0;
                    t->next_index = next0;
                    clib_memcpy(t->new_src_mac, en0->src_address,
                                sizeof(t->new_src_mac));
                    clib_memcpy(t->new_dst_mac, en0->dst_address,
                                sizeof(t->new_dst_mac));
                }
                if (b1->flags & VLIB_BUFFER_IS_TRACED) {
                    dpi_trace_t *t =
                        vlib_add_trace(vm, node, b1, sizeof(*t));
                    t->sw_if_index = sw_if_index1;
                    t->next_index = next1;
                    clib_memcpy(t->new_src_mac, en1->src_address,
                                sizeof(t->new_src_mac));
                    clib_memcpy(t->new_dst_mac, en1->dst_address,
                                sizeof(t->new_dst_mac));
                }
            }

            /* verify speculative enqueues, maybe switch current next frame */
            vlib_validate_buffer_enqueue_x2(vm, node, next_index,
                                            to_next, n_left_to_next,
                                            bi0, bi1, next0, next1);
        }

        /* Single loop: remaining buffers, one at a time */
        while (n_left_from > 0 && n_left_to_next > 0) {
            u32 bi0;
            vlib_buffer_t *b0;
            u32 next0 = dpi->next_node_index;
            u32 sw_if_index0;
            ethernet_header_t *en0;

            /* speculatively enqueue b0 to the current next frame */
            bi0 = from[0];
            to_next[0] = bi0;
            from += 1;
            to_next += 1;
            n_left_from -= 1;
            n_left_to_next -= 1;

            b0 = vlib_get_buffer(vm, bi0);

            /*
             * Direct from the driver, we should be at offset 0
             * aka at &b0->data[0]
             */
            ASSERT(b0->current_data == 0);
            process_dpi_data(vm, worker, b0, current_time);
            CPT_PACKETS ++;
            CPT_VOLUME += b0->current_length;

            /* Ethernet header and RX interface are only needed for tracing */
            en0 = vlib_buffer_get_current(b0);
            sw_if_index0 = vnet_buffer(b0)->sw_if_index[VLIB_RX];
            if (PREDICT_FALSE((node->flags & VLIB_NODE_FLAG_TRACE)
                              && (b0->flags & VLIB_BUFFER_IS_TRACED))) {
                dpi_trace_t *t =
                    vlib_add_trace(vm, node, b0, sizeof(*t));
                t->sw_if_index = sw_if_index0;
                t->next_index = next0;
                clib_memcpy(t->new_src_mac, en0->src_address,
                            sizeof(t->new_src_mac));
                clib_memcpy(t->new_dst_mac, en0->dst_address,
                            sizeof(t->new_dst_mac));
            }

            /* verify speculative enqueue, maybe switch current next frame */
            vlib_validate_buffer_enqueue_x1(vm, node, next_index,
                                            to_next, n_left_to_next,
                                            bi0, next0);

        }

        vlib_put_next_frame(vm, node, next_index, n_left_to_next);
    }

    /* Publish the per-frame totals (the macro skips zero increments) */
    DPI_COUNTER_INC(vm, PACKETS, CPT_PACKETS);
    DPI_COUNTER_INC(vm, VOLUME, CPT_VOLUME);

    return frame->n_vectors;
}

/* *INDENT-OFF* */
#ifndef CLIB_MARCH_VARIANT
/*
 * Registration of the "dpi" graph node: an internal node whose frame
 * elements are u32-sized, using format_dpi_trace as trace formatter and
 * dpi_error_strings as counter descriptions. It declares DPI_N_NEXT next
 * nodes; the only one named here is "ethernet-input".
 */
VLIB_REGISTER_NODE(dpi_node) = {
    .name = "dpi",
    .vector_size = sizeof(u32),
    .format_trace = format_dpi_trace,
    .type = VLIB_NODE_TYPE_INTERNAL,

    /* Counters generated from foreach_dpi_error */
    .n_errors = ARRAY_LEN(dpi_error_strings),
    .error_strings = dpi_error_strings,

    .n_next_nodes = DPI_N_NEXT,

    /* edit / add dispositions here */
    .next_nodes = {
        [DPI_NEXT_ETHERNET_INPUT] = "ethernet-input",
    },
};
#endif /* CLIB_MARCH_VARIANT */
/* *INDENT-ON* */
/*
 * fd.io coding-style-patch-verification: ON
 *
 * Local Variables:
 * eval: (c-set-style "gnu")
 * End:
 */
