//go:build linux && seccomp

// seccompagent is an example implementation of a seccomp-agent for the seccomp
// user notification feature. It intercepts a handful of system calls and
// emulates them.
//
// This tool is only intended to be used within runc's integration tests.
package main

import (
	"bytes"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"net"
	"os"
	"path/filepath"
	"strings"

	securejoin "github.com/cyphar/filepath-securejoin"
	"github.com/opencontainers/runc/internal/linux"
	"github.com/opencontainers/runtime-spec/specs-go"
	libseccomp "github.com/seccomp/libseccomp-golang"
	"github.com/sirupsen/logrus"
	"golang.org/x/sys/unix"
)

// Command-line options, bound to flags in main:
//   - socketFile is the path of the unix socket the agent listens on
//     (-socketfile flag).
//   - pidFile is the file the agent's pid is written to (-pid-file flag);
//     when empty, no pid file is written.
var (
	socketFile string
	pidFile    string
)

// closeStateFds closes every file descriptor in recvFds. Errors returned by
// close are ignored.
func closeStateFds(recvFds []int) {
	for i := 0; i < len(recvFds); i++ {
		_ = unix.Close(recvFds[i])
	}
}

// parseStateFds picks the seccomp fd out of recvFds and closes all the other
// received fds. On error, none of the fds is closed; that is left to the
// caller.
//
// stateFds is expected to be formatted like specs.ContainerProcessState.Fds,
// and recvFds to be the list of fds received in the same SCM_RIGHTS message,
// so that stateFds[i] names recvFds[i].
//
// An error is returned when specs.SeccompFdName is absent from stateFds,
// appears more than once, or has no matching entry in recvFds.
func parseStateFds(stateFds []string, recvFds []int) (uintptr, error) {
	// Remember where the seccomp fd name first shows up and how many times
	// it appears overall; anything other than exactly once is malformed.
	seccompIdx, matches := -1, 0
	for i, name := range stateFds {
		if name != specs.SeccompFdName {
			continue
		}
		if matches == 0 {
			seccompIdx = i
		}
		matches++
	}
	if matches != 1 {
		return 0, errors.New("seccomp fd not found or malformed containerProcessState.Fds")
	}

	if seccompIdx >= len(recvFds) {
		return 0, errors.New("seccomp fd index out of range")
	}

	// Only the seccomp fd is kept; drop everything else we were handed.
	for i, fd := range recvFds {
		if i != seccompIdx {
			unix.Close(fd)
		}
	}

	return uintptr(recvFds[seccompIdx]), nil
}

// handleNewMessage reads a single message from sockfd. The message payload is
// a JSON-encoded specs.ContainerProcessState and its control data carries the
// file descriptors listed in that state.
//
// It returns the seccomp fd and the Metadata field of the container process
// state. If the state cannot be decoded or the seccomp fd cannot be located,
// every received fd is closed before the error is returned.
func handleNewMessage(sockfd int) (uintptr, string, error) {
	const maxNameLen = 4096
	// Room for one control message with 4 bytes of data, i.e. a single fd.
	oobSpace := unix.CmsgSpace(4)
	payload := make([]byte, maxNameLen)
	control := make([]byte, oobSpace)

	n, oobn, _, _, err := unix.Recvmsg(sockfd, payload, control, 0)
	if err != nil {
		return 0, "", err
	}
	// A completely filled payload buffer or a control message of unexpected
	// size is rejected.
	if oobn != oobSpace || n >= maxNameLen {
		return 0, "", fmt.Errorf("recvfd: incorrect number of bytes read (n=%d oobn=%d)", n, oobn)
	}

	scms, err := unix.ParseSocketControlMessage(control[:oobn])
	if err != nil {
		return 0, "", err
	}
	if len(scms) != 1 {
		return 0, "", fmt.Errorf("recvfd: number of SCMs is not 1: %d", len(scms))
	}

	fds, err := unix.ParseUnixRights(&scms[0])
	if err != nil {
		return 0, "", err
	}

	var state specs.ContainerProcessState
	if err := json.Unmarshal(payload[:n], &state); err != nil {
		closeStateFds(fds)
		return 0, "", fmt.Errorf("cannot parse OCI state: %w", err)
	}

	// parseStateFds closes nothing on failure, so clean up here.
	seccompFd, err := parseStateFds(state.Fds, fds)
	if err != nil {
		closeStateFds(fds)
		return 0, "", err
	}

	return seccompFd, state.Metadata, nil
}

// readArgString reads a NUL-terminated string from the memory of process pid,
// starting at address offset, through /proc/<pid>/mem. At most 4095 bytes are
// returned; longer strings are truncated.
func readArgString(pid uint32, offset int64) (string, error) {
	memPath := fmt.Sprintf("/proc/%d/mem", pid)
	memfd, err := linux.Open(memPath, unix.O_RDONLY, 0o777)
	if err != nil {
		return "", err
	}
	defer unix.Close(memfd)

	raw := make([]byte, 4096) // PATH_MAX
	if _, err := unix.Pread(memfd, raw, offset); err != nil {
		return "", err
	}

	// Force a terminator so the search below always finds one.
	raw[len(raw)-1] = 0
	end := bytes.IndexByte(raw, 0)
	return string(raw[:end]), nil
}

// runMkdirForContainer creates the directory "<fileName>-<metadata>" with the
// given mode on behalf of process pid. Absolute paths are resolved against the
// process's root (/proc/<pid>/root/), relative ones against its current
// working directory (/proc/<pid>/cwd/).
//
// The caller must have sanitized metadata already, so that appending it
// cannot move the resulting path to a different location.
func runMkdirForContainer(pid uint32, fileName string, mode uint32, metadata string) error {
	var base string
	if strings.HasPrefix(fileName, "/") {
		// Absolute path: resolve inside the container's rootfs.
		base = fmt.Sprintf("/proc/%d/root/", pid)
	} else {
		base = fmt.Sprintf("/proc/%d/cwd/", pid)
	}

	target, err := securejoin.SecureJoin(base, fmt.Sprintf("%s-%s", fileName, metadata))
	if err != nil {
		return err
	}

	return unix.Mkdir(target, mode)
}

// notifHandler receives seccomp notifications on fd in an endless loop and
// responds to each of them:
//
//   - mkdir is emulated by the agent: the directory is created on behalf of
//     the calling process with "-<metadata>" appended to the requested path.
//     If the path cannot be read or the directory cannot be created, the
//     syscall fails with ENOSYS.
//   - chmod, fchmod and fchmodat always fail with ENOMEDIUM.
//   - Any other syscall gets a response carrying NotifRespFlagContinue.
//
// Notifications whose ID is no longer valid are dropped without a response.
func notifHandler(fd libseccomp.ScmpFd, metadata string) {
	defer unix.Close(int(fd))
	for {
		req, err := libseccomp.NotifReceive(fd)
		if err != nil {
			// NOTE(review): every error is retried; if the error is
			// persistent this loop spins. Confirm which errors are transient
			// before changing it.
			logrus.Errorf("Error in NotifReceive(): %s", err)
			continue
		}
		syscallName, err := req.Data.Syscall.GetName()
		if err != nil {
			logrus.Errorf("Error decoding syscall %v(): %s", req.Data.Syscall, err)
			continue
		}
		logrus.Debugf("Received syscall %q, pid %v, arch %q, args %+v", syscallName, req.Pid, req.Data.Arch, req.Data.Args)

		// Default response, used for syscalls that are not emulated below.
		resp := &libseccomp.ScmpNotifResp{
			ID:    req.ID,
			Error: 0,
			Val:   0,
			Flags: libseccomp.NotifRespFlagContinue,
		}

		// TOCTOU check
		if err := libseccomp.NotifIDValid(fd, req.ID); err != nil {
			logrus.Errorf("TOCTOU check failed: req.ID is no longer valid: %s", err)
			continue
		}

		switch syscallName {
		case "mkdir":
			// mkdir is always emulated, never continued. Clear the flag up
			// front so that every path below, including the failure paths,
			// sends a consistent response: the continue flag must not be
			// combined with a non-zero Error or Val, or the response is
			// rejected and the caller stays blocked.
			resp.Flags = 0

			fileName, err := readArgString(req.Pid, int64(req.Data.Args[0]))
			if err != nil {
				logrus.Errorf("Cannot read argument: %s", err)
				resp.Error = int32(unix.ENOSYS)
				resp.Val = ^uint64(0) // -1
				break
			}

			logrus.Debugf("mkdir: %q", fileName)

			// TOCTOU check: make sure the memory we just read still belonged
			// to the process that issued this notification.
			if err := libseccomp.NotifIDValid(fd, req.ID); err != nil {
				logrus.Errorf("TOCTOU check failed: req.ID is no longer valid: %s", err)
				continue
			}

			err = runMkdirForContainer(req.Pid, fileName, uint32(req.Data.Args[1]), metadata)
			if err != nil {
				logrus.Errorf("Cannot create directory %q: %s", fileName, err)
				resp.Error = int32(unix.ENOSYS)
				resp.Val = ^uint64(0) // -1
			}
		case "chmod", "fchmod", "fchmodat":
			resp.Error = int32(unix.ENOMEDIUM)
			resp.Val = ^uint64(0) // -1
			resp.Flags = 0
		}

		if err = libseccomp.NotifRespond(fd, resp); err != nil {
			logrus.Errorf("Error in notification response: %s", err)
		}
	}
}

// main parses the command line, optionally writes a pid file, and then
// listens on the unix socket at socketFile. Each accepted connection is
// expected to deliver one message containing a seccomp fd; a goroutine running
// notifHandler is started for every fd received.
func main() {
	flag.StringVar(&socketFile, "socketfile", "/run/seccomp-agent.socket", "Socket file")
	flag.StringVar(&pidFile, "pid-file", "", "Pid file")
	logrus.SetLevel(logrus.DebugLevel)

	flag.Parse()
	// Positional arguments are not accepted.
	if flag.NArg() > 0 {
		flag.PrintDefaults()
		logrus.Fatal("Invalid command")
	}

	// Remove any existing socket file; it is fine if there is none.
	switch err := os.Remove(socketFile); {
	case err == nil, errors.Is(err, os.ErrNotExist):
	default:
		logrus.Fatalf("Cannot cleanup socket file: %v", err)
	}

	if pidFile != "" {
		contents := []byte(fmt.Sprintf("%d", os.Getpid()))
		if err := os.WriteFile(pidFile, contents, 0o644); err != nil {
			logrus.Fatalf("Cannot write pid file: %v", err)
		}
	}

	logrus.Info("Waiting for seccomp file descriptors")
	listener, err := net.Listen("unix", socketFile)
	if err != nil {
		logrus.Fatalf("Cannot listen: %s", err)
	}
	defer listener.Close()

	for {
		conn, err := listener.Accept()
		if err != nil {
			logrus.Errorf("Cannot accept connection: %s", err)
			continue
		}
		// Only the raw descriptor obtained through File is needed, so the
		// connection is closed immediately.
		sockFile, err := conn.(*net.UnixConn).File()
		conn.Close()
		if err != nil {
			logrus.Errorf("Cannot get socket: %v", err)
			continue
		}
		seccompFd, metadata, err := handleNewMessage(int(sockFile.Fd()))
		sockFile.Close()
		if err != nil {
			logrus.Errorf("Error receiving seccomp file descriptor: %v", err)
			continue
		}

		// The metadata ends up as a file name suffix, so strings such as
		// "/../p" must not be allowed: they would point to a different
		// location than expected. Keep only the last path element, and use
		// a fixed safe string if a separator survives even that.
		suffix := filepath.Base(metadata)
		if strings.Contains(suffix, "/") {
			suffix = "agent-generated-suffix"
		}

		logrus.Infof("Received new seccomp fd: %v", seccompFd)
		go notifHandler(libseccomp.ScmpFd(seccompFd), suffix)
	}
}
