/*
 *****************************************************************************
 * Copyright (C) 2017, Cisco Systems
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 ****************************************************************************
 *
 *  File:    utils.c
 *  Author:  Koushik Chakravarty <kouchakr@cisco.com>
 *
 ****************************************************************************
 *
 *  This file contains the utility APIs, implemented for both the kernel build and the user-space (NVM_BPF_USERSPACE) build.
 *
 *****************************************************************************
 */

#include "utils.h"
#include "dbgout.h"

#ifdef NVM_BPF_USERSPACE
#include <limits.h>
#else
#include <linux/net.h>
#include <linux/sched.h>
#include <linux/fs.h>
#include <linux/file.h>
#include <linux/time.h>
#include <linux/slab.h>
#include <linux/rcupdate.h>
#include "defines.h"

#if (LINUX_VERSION_CODE >= KERNEL_VERSION(4, 13, 0))
#include <linux/sched/task.h>
#include <linux/sched/mm.h>
#else
#include <linux/mm.h>
#endif
#endif
/* Fallback name used when a task's name or executable path cannot be resolved. */
const char *default_name = "Unknown";

/*
*   \brief Return the task_struct of the calling context with an extra
*          reference taken; drop it with unref_task().
*          The user-space (NVM_BPF_USERSPACE) build has no task_struct for
*          the caller and always returns NULL.
*/
struct task_struct *get_curr_task(void)
{
#ifndef NVM_BPF_USERSPACE
	struct task_struct *self = current;

	if (self)
		get_task_struct(self);
	return self;
#else
	return NULL;
#endif
}

/*
*   \brief Return the process id (tgid) of a task, or 0 when task is NULL.
*/
pid_t get_pid_of_task(struct task_struct *task)
{
	if (!task)
		return 0;
	/* Every thread of a process shares the same tgid, which is the
	 * process id; pid alone would identify the individual thread. */
	return task->tgid;
}

/*
*   \brief Look up a task by process id and return it with a reference held.
*
*   The caller owns the result and must release it with unref_task().
*   Returns NULL if uPid is not positive or no such task exists.
*
*   User-space build (NVM_BPF_USERSPACE): a heap-allocated task_struct is
*   synthesised from /proc. Only tgid and comm are filled in; start_time is 0
*   and every other field is zero-initialised.
*/
struct task_struct *get_task_from_pid(pid_t uPid)
{
#ifdef NVM_BPF_USERSPACE
    char proc_pid_path[PATH_MAX];
    struct task_struct *task;
    int ret;

    if (uPid <= 0) {
        return NULL;
    }

    ret = snprintf(proc_pid_path, sizeof(proc_pid_path), "/proc/%d", uPid);
    if (ret < 0 || (size_t)ret >= sizeof(proc_pid_path)) {
        return NULL;
    }
    /* If the process directory does not exist, the process is not alive. */
    if (access(proc_pid_path, F_OK) != 0) {
        return NULL;
    }

    /* calloc rather than malloc: fields this function does not set (and
     * task->comm, which get_taskname() inspects) start out zeroed instead
     * of holding indeterminate heap contents. */
    task = calloc(1, sizeof(*task));
    if (!task) {
        return NULL;
    }

    task->tgid = uPid;
    if (get_taskname(task, task->comm, sizeof(task->comm)) == 0) {
        /* No name could be read: store a placeholder (always NUL-terminated). */
        snprintf(task->comm, sizeof(task->comm), "%s", "unknown");
    }
    /* Placeholder start time; updated by nvm_plugin once the process
     * creation time is available from eBPF. */
    task->start_time = 0;
    return task;
#else
	struct pid *pid_ref;
	struct task_struct *task;

	if (uPid <= 0)
		return NULL;
	/* find_vpid() requires the caller to hold rcu_read_lock();
	 * find_get_pid() does the lookup under RCU and returns a counted
	 * struct pid reference, which is dropped once the task is pinned. */
	pid_ref = find_get_pid(uPid);
	task = get_pid_task(pid_ref, PIDTYPE_PID);
	put_pid(pid_ref);
	return task;
#endif
}

/*
*   \brief Get the parent task of a task, with a reference held.
*
*   The caller must drop the returned reference with unref_task().
*   Returns NULL if task is NULL or the parent cannot be determined.
*
*   Kernel build: returns task->parent, read and pinned under RCU.
*   User-space build (NVM_BPF_USERSPACE): the parent pid is parsed from
*   /proc/<tgid>/stat and resolved with get_task_from_pid().
*/
struct task_struct *get_parent(struct task_struct *task)
{
#ifdef NVM_BPF_USERSPACE
    char stat_path[PATH_MAX];
    char buf[512];
    pid_t parent_pid = 0;
    FILE *fp;
    int ret;

    if (!task) {
        TRACE(ERROR, LOG("Invalid task struct"));
        return NULL;
    }

    /* task->tgid is this task's own pid; its stat file holds the parent pid. */
    ret = snprintf(stat_path, sizeof(stat_path), "/proc/%d/stat", task->tgid);
    if (ret < 0 || (size_t)ret >= sizeof(stat_path)) {
        return NULL;
    }

    fp = fopen(stat_path, "r");
    if (!fp) {
        return NULL;
    }

    /*
     * Line format: "pid (comm) state ppid ...". comm may itself contain
     * spaces or ')', so parsing starts after the LAST ')'. Starting at
     * rparen + 1 and letting the leading space in the format skip whitespace
     * never reads past the terminator, even if the line ends right after ')'.
     */
    if (fgets(buf, sizeof(buf), fp)) {
        char *rparen = strrchr(buf, ')');

        if (!rparen || sscanf(rparen + 1, " %*c %d", &parent_pid) != 1) {
            parent_pid = 0;
        }
    }
    fclose(fp);

    /* ppid 0 or an unparsable stat line: no parent to return. */
    if (parent_pid <= 0) {
        return NULL;
    }
    return get_task_from_pid(parent_pid);
#else
	struct task_struct *parent;

	if (NULL == task)
		return NULL;
	/* task->parent is RCU-protected: the parent can exit concurrently, so
	 * the pointer is read and its reference taken inside an RCU read-side
	 * section, which keeps the task_struct from being freed meanwhile. */
	rcu_read_lock();
	parent = rcu_dereference(task->parent);
	if (NULL != parent)
		get_task_struct(parent);
	rcu_read_unlock();
	return parent;
#endif
}

/*
*   \brief Release a task obtained from get_curr_task(), get_task_from_pid()
*          or get_parent(). A NULL task is ignored.
*          Kernel build: drops the reference with put_task_struct().
*          User-space build: frees the heap-allocated task.
*/
void unref_task(struct task_struct *task)
{
	if (task == NULL)
		return;
#ifdef NVM_BPF_USERSPACE
	free(task);
#else
	put_task_struct(task);
#endif
}

#ifndef NVM_BPF_USERSPACE
/*
*   \brief Get the executable path from the task struct
*/
uint16_t GetExePathFromTaskGeneric(struct task_struct *task, char *path_buffer,
				   uint16_t buffer_size, bool bCurrent)
{
	uint16_t retval = 0;
	struct mm_struct *mem_mgr = NULL;
	char *temp_buffer = NULL;
	const char *path = default_name;

	if (NULL == task || NULL == path_buffer || 0 == buffer_size) {
		TRACE(ERROR, LOG("Invalid parameters"));
		return retval;
	}

	temp_buffer = KMALLOC(PATH_MAX);
	if (NULL == temp_buffer) {
		TRACE(ERROR, LOG("Failed to allocate temporary buffer"));
		return retval;
	}
	memset(temp_buffer, 0, PATH_MAX);
	/* We are not calling get_task_mm/mmput as we
	   are in the context of the same task and mmput can sleep() */
	mem_mgr = (bCurrent) ? task->mm : get_task_mm(task);
	if (NULL == mem_mgr || NULL == mem_mgr->exe_file) {
		TRACE(WARNING, LOG("Failed to get task details"));
		goto done;
	}
#if (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 8, 0))
	mmap_read_lock(mem_mgr);
#else
	down_read(&mem_mgr->mmap_sem);
#endif
	path = d_path(&mem_mgr->exe_file->f_path, temp_buffer, PATH_MAX);
	if (IS_ERR(path)) {
		TRACE(WARNING, LOG("Failed to get file path"));
		path = default_name;
	}
#if (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 8, 0))
	mmap_read_unlock(mem_mgr);
#else
	up_read(&mem_mgr->mmap_sem);
#endif
done:
    if (NULL != mem_mgr && !bCurrent)
        mmput(mem_mgr);

    snprintf(path_buffer, buffer_size, "%s", path);
    retval = strlen(path_buffer);
    if (NULL != temp_buffer)
        KFREE(temp_buffer);

    return retval;
}
/*
 *   \brief Get the executable path of the currently running task.
 *          Uses task->mm directly (no get_task_mm()/mmput()), so task is
 *          expected to be the caller's own task.
 */
uint16_t
get_exepath_from_curr_task(struct task_struct *task, char *path_buffer,
			   uint16_t buffer_size)
{
	const bool is_current_task = true;

	return GetExePathFromTaskGeneric(task, path_buffer, buffer_size,
					 is_current_task);
}
#endif

/*
*   \brief Wrapper to get the executable path from a task struct
*
*   Kernel build: resolves the path with GetExePathFromTaskGeneric() (using
*   get_task_mm()/mmput()) and returns its length.
*   User-space build (NVM_BPF_USERSPACE): the path is not fetched here -- the
*   ProcessTree plugin provides it -- so path_buffer is set to "" and 0 is
*   returned. A NULL or zero-sized path_buffer is never written to.
*/
uint16_t get_exepath_from_task(struct task_struct *task, char *path_buffer,
			       uint16_t buffer_size)
{
#ifdef NVM_BPF_USERSPACE
	(void)task;
	/* Same parameter contract as the kernel path: do not touch a NULL or
	 * zero-length buffer. */
	if (NULL != path_buffer && 0 != buffer_size)
		path_buffer[0] = '\0';
	return 0;
#else
	return GetExePathFromTaskGeneric(task, path_buffer, buffer_size, false);
#endif
}

/**
 *  \brief Get a task's creation time in milliseconds.
 *
 *  \param[in]  task  task to query
 *  \param[out] time  receives task->start_time / 1000000, i.e. start_time
 *                    converted from nanoseconds to milliseconds; set to 0
 *                    if task is NULL. Nothing is written if time is NULL.
 */
void get_process_creation_time(struct task_struct *task, uint64_t *time)
{
	if (NULL == task || NULL == time) {
		TRACE(ERROR, LOG("Invalid parameters"));
		/* Only report 0 through a valid output pointer; writing through a
		 * NULL `time` would crash in the very check meant to catch it. */
		if (NULL != time)
			*time = 0;
		return;
	}
	*time = (task->start_time) / 1000000; /* nanoseconds -> milliseconds */
}

/*
 * \brief Return the current wall-clock time in whole seconds since the Unix
 *        epoch, narrowed to uint32_t. The user-space build returns 0 if
 *        gettimeofday() fails.
 */
uint32_t get_unix_systime(void)
{
#ifndef NVM_BPF_USERSPACE
#if (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 0, 0))
	struct timespec64 ts;

	ktime_get_real_ts64(&ts);
#else
	struct timeval ts;

	do_gettimeofday(&ts);
#endif
	return ts.tv_sec;
#else
	struct timeval tv;

	if (0 != gettimeofday(&tv, NULL))
		return 0;
	return (uint32_t)tv.tv_sec;
#endif
}

/*
*   \brief Get the executable name from the task struct
*
*   Copies the task's name into name_buffer (always NUL-terminated on
*   success) and returns its length; returns 0 on invalid parameters or
*   when no name is available.
*
*   Kernel build: the name comes from get_task_comm(); an empty name is
*   reported as default_name. name_buffer must hold at least TASK_COMM_LEN
*   bytes.
*   User-space build (NVM_BPF_USERSPACE): task->comm is used if already
*   populated, otherwise the name is read from /proc/<tgid>/comm.
*/
uint16_t get_taskname(struct task_struct *task, char *name_buffer,
		      uint16_t buffer_size)
{
#ifdef NVM_BPF_USERSPACE
    char comm_path[PATH_MAX];
    char *line = NULL;
    size_t line_cap = 0;
    ssize_t nread;
    FILE *fp;
    int ret;

    if (!task || !name_buffer || buffer_size < 2) {  /* room for 1 char + NUL */
        return 0;
    }

    /* Clear the output first. Callers may pass task->comm itself as
     * name_buffer; clearing it here makes the check below fall through to
     * /proc instead of copying a buffer onto itself. */
    name_buffer[0] = '\0';

    if (task->comm[0] != '\0') {
        /* comm already populated: copy it. The precision bounds the read to
         * the comm array and snprintf always NUL-terminates the output,
         * unlike strncpy when the source fills the buffer. */
        snprintf(name_buffer, buffer_size, "%.*s",
                 (int)sizeof(task->comm), task->comm);
        return (uint16_t)strlen(name_buffer);
    }

    /* Try reading from /proc/[pid]/comm */
    ret = snprintf(comm_path, sizeof(comm_path), "/proc/%d/comm", task->tgid);
    if (ret < 0 || (size_t)ret >= sizeof(comm_path)) {
        return 0;
    }

    fp = fopen(comm_path, "r");
    if (!fp) {
        return 0;
    }
    nread = getline(&line, &line_cap, fp);
    fclose(fp);

    if (nread > 0) {
        /* Remove trailing newline if present */
        if (line[nread - 1] == '\n') {
            line[nread - 1] = '\0';
        }
        snprintf(name_buffer, buffer_size, "%s", line);
    }
    /* getline() may allocate even when it fails; free(NULL) is a no-op. */
    free(line);

    return (uint16_t)strlen(name_buffer);
#else
	char buf[TASK_COMM_LEN] = {0};

	/* buf holds at most TASK_COMM_LEN - 1 characters plus the NUL, so a
	 * buffer of exactly TASK_COMM_LEN bytes is large enough. */
	if (NULL == task || NULL == name_buffer || buffer_size < TASK_COMM_LEN) {
		TRACE(ERROR, LOG("Invalid parameters"));
		return 0;
	}
	get_task_comm(buf, task);
	snprintf(name_buffer, buffer_size, "%s", ('\0' == buf[0] ? default_name : buf));
	return (uint16_t)strlen(name_buffer);
#endif
}

/**
 * \brief sends data over socket
 * \description
 *  Sends buffer as a single non-blocking datagram (MSG_DONTWAIT) from local
 *  to dest. A short send is treated as a failure. The user-space build
 *  (NVM_BPF_USERSPACE) does not send anything and returns SUCCESS.
 *
 * \param[in] local local socket via which the data is to be sent
 * \param[in] dest destination socket
 * \param[in] buffer data to be sent
 * \param[in] buff_len length of the buffer to be sent
 *
 * \return SUCCESS if all buff_len bytes were sent, ERROR_BAD_PARAM on NULL
 *         arguments, ERROR_ERROR otherwise
*/
error_code socket_sendto(struct socket *local, struct sockaddr_in *dest,
			   const uint8_t *buffer, size_t buff_len)
{
#ifndef NVM_BPF_USERSPACE
	struct msghdr msg = {0};
	struct kvec vec = {0};
	int bytes_sent;

	if ((NULL == local) || (NULL == dest) || (NULL == buffer)) {
		TRACE(ERROR,
		      LOG
		      ("Unable to send data over socket. Invalid parameters"));
		return ERROR_BAD_PARAM;
	}

	/* msg_control/msg_controllen stay zero from the initializer. */
	msg.msg_flags = MSG_DONTWAIT;
	msg.msg_name = dest;
	msg.msg_namelen = sizeof(struct sockaddr_in);

	vec.iov_base = (uint8_t *) buffer;
	vec.iov_len = buff_len;

	/* kernel_sendmsg() returns an int: the byte count, or a negative errno.
	 * Keep it signed so errors are not wrapped into huge size_t values. */
	bytes_sent = kernel_sendmsg(local, &msg, &vec, 1, vec.iov_len);

	if (bytes_sent < 0 || (size_t)bytes_sent != buff_len) {
		TRACE(ERROR,
		      LOG
		      ("kernel_sendmsg: sent %zu, ret %d", buff_len, bytes_sent));
		return ERROR_ERROR;
	}
	return SUCCESS;
#else
	(void)local;
	(void)dest;
	(void)buffer;
	(void)buff_len;
	return SUCCESS;
#endif
}
