#!/usr/bin/env lua

--require "format"

--[[
  a    file access mode
  c    process command name (all characters from proc or user structure)
  C    file structure share count
  d    file's device character code
  D    file's major/minor device number (0x<hexadecimal>)
  f    file descriptor (always selected)
  F    file structure address (0x<hexadecimal>)
  G    file flaGs (0x<hexadecimal>; names if +fg follows)
  g    process group ID
  i    file's inode number
  K    tasK ID
  k    link count
  l    file's lock status
  L    process login name
  m    marker between repeated output
  n    file name, comment, Internet address
  N    node identifier (0x<hexadecimal>)
  o    file's offset (decimal)
  p    process ID (always selected)
  P    protocol name
  r    raw device number (0x<hexadecimal>)
  R    parent process ID
  s    file's size (decimal)
  S    file's stream identification
  t    file's type
  T    TCP/TPI information, identified by prefixes (the
       `=' is part of the prefix):
           QR=<read queue size>
           QS=<send queue size>
           SO=<socket options and values> (not all dialects)
           SS=<socket states> (not all dialects)
           ST=<connection state>
           TF=<TCP flags and values> (not all dialects)
           WR=<window read size>  (not all dialects)
           WW=<window write size>  (not all dialects)
       (TCP/TPI information isn't reported for all supported
         UNIX dialects. The -h or -? help output for the
         -T option will show what TCP/TPI reporting can be
         requested.)
  u    process user ID
  z    Solaris 10 and higher zone name
  Z    SELinux security context (inhibited when SELinux is disabled)
  0    use NUL field terminator character in place of NL
  1-9  dialect-specific field identifiers (The output of -F? identifies the information to be found in dialect-specific fields.)
]]


--- Format a string and write it to the default output.
-- No newline is appended; callers include it in the format themselves.
-- @param fmt format string as accepted by string.format
-- @param ... values substituted into fmt
local function printf(fmt, ...)
	local text = string.format(fmt, ...)
	io.write(text)
end

--
-- Parse lsof field output (`lsof -F`) into lua tables
--
-- Lines are read from the default input. Each line is one field: the first
-- character is the field identifier, the remainder is its value. A 'p' line
-- opens a new process record, an 'f' line opens a new file record inside the
-- current process, and every other field is stored on whichever record was
-- opened last, using the field identifier as the key.
--
-- Returns a table mapping pid (string) to a process record. A process record
-- holds its own fields plus `files`, a list of file records; each file record
-- holds its fields plus `proc`, a reference back to the owning process.
--
-- A pid that shows up a second time is ignored together with its files, and
-- processes owning a file with descriptor "txt" of type "unknown" (kernel
-- threads) are removed from the result.
--

local function parse_lsof()

	local by_pid = {}

	-- Parser state: the process and file being filled in, and `target`, the
	-- record that receives the fields of the current line (nil = discard).
	local target, current_proc, current_file

	for line in io.lines() do

		-- A column header means the human-readable table format was piped in.
		if line:find("^COMMAND") then
			io.stderr:write("Unexpected input, did you run lsof with the `-F` flag?\n")
			os.exit(1)
		end

		local field = line:sub(1, 1)
		local value = line:sub(2)

		if field == 'p' then
			if by_pid[value] then
				-- pid already seen: drop everything up to the next 'p' line
				current_proc = nil
				target = nil
			else
				current_proc = { files = {} }
				current_file = nil
				target = current_proc
				by_pid[value] = current_proc
			end
		elseif field == 'f' and current_proc then
			current_file = { proc = current_proc }
			target = current_file
			current_proc.files[#current_proc.files + 1] = current_file
		end

		if target then
			target[field] = value
		end

		-- skip kernel threads: forget the whole process as soon as one of its
		-- files turns out to be a "txt" descriptor of unknown type

		local is_kernel_thread = current_proc and current_file
			and current_file.f == "txt" and current_file.t == "unknown"

		if is_kernel_thread then
			by_pid[current_proc.p] = nil
			current_proc = nil
			current_file = nil
			target = nil
		end

	end

	return by_pid
end


--- Group open files into candidate connections between processes.
--
-- Every file of every process is inspected and, depending on its lsof type
-- (field `t`), appended to a bucket under a key that both ends of the same
-- connection share:
--
--   unix       inode (`i`), falling back to the device field (`d`)
--   fifo       inode (`i`)
--   pipe       BSD/MacOS pipes: sorted pair of the file's own device (`d`)
--              and the peer named after "->" in the name field (`n`)
--   tcp / udp  sorted pair of the two endpoints around "->" in the name
--              field; files whose protocol (`P`) is not "TCP" go to `udp`
--
-- The two halves of a pair key are joined with a literal backslash followed
-- by `n` (not a newline character).
--
-- Files that lack the data needed to build a key are skipped: sockets whose
-- name has no "->" (e.g. listening sockets), FIFOs without an inode, pipes
-- without a device or name. They cannot be matched to a peer, and filing
-- them under a shared empty key would pair up unrelated processes.
--
-- @param procs table of process records as returned by parse_lsof; each
--        record needs a `files` list
-- @return table with the buckets `fifo`, `unix`, `tcp`, `udp` and `pipe`,
--         each mapping a key to the list of file records sharing that key
local function find_conns(procs)

	local cs = {
		fifo = {}, -- index by inode
		unix = {}, -- index by inode
		tcp  = {}, -- index by sorted endpoints
		udp  = {}, -- index by sorted endpoints
		pipe = {}, -- index by sorted endpoints
	}

	-- Append `file` to the list stored under `key`; a nil key means the file
	-- cannot be matched and is dropped.
	local function add(bucket, key, file)
		if key == nil then
			return
		end
		local list = bucket[key]
		if not list then
			list = {}
			bucket[key] = list
		end
		list[#list + 1] = file
	end

	-- Build a key for two endpoints that is the same no matter which side
	-- of the connection reported it.
	local function pair_key(a, b)
		local ends = { a, b }
		table.sort(ends)
		return table.concat(ends, "\\n")
	end

	for _, proc in pairs(procs) do
		for _, file in ipairs(proc.files) do

			local t = file.t

			if t == "unix" then
				add(cs.unix, file.i or file.d, file)

			elseif t == "FIFO" then
				add(cs.fifo, file.i, file)

			elseif t == "PIPE" then -- BSD/MacOS
				if file.d and file.n then
					for peer in file.n:gmatch("%->(.+)") do
						add(cs.pipe, pair_key(file.d, peer), file)
					end
				end

			elseif t == "IPv4" or t == "IPv6" then
				local a, b = (file.n or ""):match("(.-)%->(.+)")
				-- no "->" means there is no peer to connect to
				if a and b then
					local bucket = (file.P == "TCP") and cs.tcp or cs.udp
					add(bucket, pair_key(a, b), file)
				end
			end
		end
	end

	return cs

end



-- Parse the lsof output and bucket the open files into connections.
local procs = parse_lsof()
local conns = find_conns(procs)


-- Generate graph: open the digraph and set graph, node and edge defaults.

local preamble = {
	"digraph G {\n",
	"	graph [ center=true, margin=0.2, nodesep=0.1, ranksep=0.3, rankdir=LR];\n",
	'	node [ shape=box, style="rounded,filled" width=0, height=0, fontname=Helvetica, fontsize=10];\n',
	"	edge [ fontname=Helvetica, fontsize=10];\n",
}

for _, chunk in ipairs(preamble) do
	printf(chunk)
end

-- Process nodes and parent/child relationships
--
-- One node per process, labelled with its name, pid and login name. Nodes
-- whose parent pid is "1" are filled grey. A grey undirected edge links a
-- process to its parent when the parent is part of the parsed set and is
-- not pid 1.

-- Escape backslashes and double quotes so arbitrary text can sit inside a
-- double-quoted DOT label without ending the string or forming an escape.
local function dot_escape(s)
	return (tostring(s):gsub('[\\"]', '\\%0'))
end

for _, proc in pairs(procs) do

	local color = (proc.R == "1") and "grey70" or "white"

	-- lsof may omit the command or login field; fall back rather than pass
	-- nil to a %s conversion (an error on Lua 5.1, the text "nil" later on).
	local name = dot_escape(proc.n or proc.c or "?")
	local login = dot_escape(proc.L or proc.u or "")

	printf('	p%d [ label = "%s\\n%d %s" fillcolor=%q ];\n', proc.p, name, proc.p, login, color)

	local proc_parent = procs[proc.R]
	if proc_parent and proc_parent.p ~= "1" then
		printf('	p%d -> p%d [ penwidth=2 weight=100 color=grey60 dir="none" ];\n', proc.R, proc.p)
	end
end

-- Connections

-- Edge colour per connection kind; kinds not listed are drawn black.
local colors = {
	fifo = "green",
	unix = "purple",
	tcp  = "red",
	udp  = "orange",
	pipe = "orange",
}

-- Arrow direction chosen from the access mode (`a` field) of the first file
-- of a pair; any other mode, or none, yields a two-headed arrow.
local arrow_for_access = {
	w = "forward",
	r = "backward",
}

for kind, bucket in pairs(conns) do

	local color = colors[kind] or "black"

	for id, files in pairs(bucket) do

		-- one-on-one connections: only keys shared by exactly two files,
		-- and only when those files belong to different processes

		local first, second = files[1], files[2]

		if #files == 2 and first.proc ~= second.proc then
			printf('	p%d -> p%d [ color="%s" label="%s" dir="%s"];\n',
				first.proc.p,
				second.proc.p,
				color,
				kind .. ":\\n" .. id,
				arrow_for_access[first.a] or "both")
		end
	end
end

-- Done: close the digraph

printf("}\n")

-- vi: ft=lua ts=3 sw=3

