// Package client implements a rsync client and functionality around the protocol
package client

// #include <stdlib.h>
// #include "native_funcs.h"
import (
	"C"
)

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"net"
	"path/filepath"
	"runtime/debug"
	"slices"
	"sort"
	"strconv"
	"strings"
	"unsafe"

	"cvr.google.com/pkg/util"
)

// MplexBase is the offset added to every multiplex tag before it is shifted
// into the top byte of the 4-byte packet header (see shiftTag).
const MplexBase = 7

// MsgData is the tag for the MSG_DATA packet, which carries regular stream data.
const MsgData = 0

// TagShift is the number of bits a tag is shifted to the left inside the packet
// header; the lower 24 bits hold the payload length (see unshiftTag).
const TagShift = 24

// NdxDone is the index value used as a delimiter between phases of the
// transfer. WriteNdx encodes it as a single 0 byte for protocol >= 30.
const NdxDone = -1

// ItemTransferFlag signals to rsync that we actually want to download the file and not just
// log it.
const ItemTransferFlag = 1 << 15

// ItemBasisTypeFollows signals to rsync that a basis-type field follows the item flags.
const ItemBasisTypeFollows = 1 << 11

// ItemXnameFollows signals to rsync that the xname field follows the item flags.
const ItemXnameFollows = 1 << 12

// FnamecmpFname is a flag that tells rsync to write to a temporary file before writing to the
// destination file. This is to prevent rsync from writing partial data to the destination file.
const FnamecmpFname = 0x80

// IntByteExtra is a lookup table for the number of extra bytes in a varint.
// It is indexed by the first byte of the encoded value divided by 4 (see
// ReadVarInt and ReadVarLong) and yields how many additional bytes follow.
// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/io.c#L119
var IntByteExtra = [64]int{
	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* (00 - 3F)/4 */
	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* (40 - 7F)/4 */
	1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, /* (80 - BF)/4 */
	2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 5, 6, /* (C0 - FF)/4 */
}

// xmit* are the per-entry flag bits of a file list entry, as interpreted by
// ReadFileList and emitted by FileList.Send.
const (
	// xmitSameMode: the mode is omitted and the previous entry's mode is reused.
	xmitSameMode = 1 << 1
	// xmitSameName: a one-byte length of the prefix shared with the previous name follows.
	xmitSameName = 1 << 5
	// xmitLongName: the name length is sent as a varint instead of a single byte.
	xmitLongName = 1 << 6
	// xmitSameTime: the modification time is omitted.
	xmitSameTime = 1 << 7
	// xmitHLinked: the entry is a hard link; unless xmitHLinkedFirst is also set,
	// a hard link index follows (protocol >= 30).
	xmitHLinked = 1 << 9
	// xmitHLinkedFirst: checked together with xmitHLinked; when set, no hard link
	// index is read.
	xmitHLinkedFirst = 1 << 12
	// xmitModNsec: a varint with the nanosecond part of the modification time follows.
	xmitModNsec = 1 << 13
)

// bitsSetAndUnset reports whether every bit of onbits is set in val while
// every bit of offbits is cleared.
func bitsSetAndUnset(val int, onbits int, offbits int) bool {
	relevant := onbits | offbits
	return val&relevant == onbits
}

// File represents a single entry of a file list: a regular file, a directory
// or a symlink, depending on the type bits in Mode.
type File struct {
	// Name is the path of the entry as it appears in the file list.
	Name string
	// Size is the file length. Entries created through the FileList helpers
	// always carry 0.
	Size int
	// Mode holds the file type and permission bits.
	Mode uint32
	// Digest is the checksum read from the wire for regular files (see
	// ReadFileList); it is nil for other entries read from a server.
	Digest []byte
	// SymlinkTarget is the link target; only meaningful for symlinks.
	SymlinkTarget string
	// Content is the payload associated with a regular file added via
	// AddRegularFile.
	Content []byte
}

// FileList represents a list of files together with optional child lists.
// IndexFor numbers the entries of the root list first and then the entries of
// each child list in order.
type FileList struct {
	// files are the entries of this list; Sort orders them in place.
	files []*File
	// childLists are additional lists created through NewChildList.
	childLists []*FileList
}

// AddRegularFile appends a regular file (mode 0755) with the given name and
// content to the list and returns the newly created entry.
func (fl *FileList) AddRegularFile(fileName string, content []byte) *File {
	entry := &File{
		Name:    fileName,
		Content: content,
		// the size is always announced as 0 because varint30 encoding of the
		// length is not implemented yet
		Size:          0,
		Mode:          util.SIFREG | 0755,
		SymlinkTarget: "",
	}
	fl.files = append(fl.files, entry)
	return entry
}

// AddSymlink appends a symlink entry pointing at target to the list and
// returns the newly created entry.
func (fl *FileList) AddSymlink(fileName string, target string) *File {
	entry := &File{
		Name:          fileName,
		SymlinkTarget: target,
		// the size is always announced as 0 because varint30 encoding of the
		// length is not implemented yet
		Size:   0,
		Mode:   util.SIFLNK,
		Digest: []byte{},
	}
	fl.files = append(fl.files, entry)
	return entry
}

// AddDir appends a directory entry (mode 0755) to the list and returns the
// newly created entry.
func (fl *FileList) AddDir(dirName string) *File {
	entry := &File{
		Name: dirName,
		// the size is always announced as 0 because varint30 encoding of the
		// length is not implemented yet
		Size:          0,
		Mode:          util.SIFDIR | 0755,
		SymlinkTarget: "",
	}
	fl.files = append(fl.files, entry)
	return entry
}

// NewFileList creates a new file list with an empty (non-nil) set of entries
// and no child lists.
func NewFileList() *FileList {
	list := new(FileList)
	list.files = make([]*File, 0)
	return list
}

// NewChildList creates an empty file list, registers it as the last child of
// fl and returns it.
func (fl *FileList) NewChildList() *FileList {
	child := NewFileList()
	fl.childLists = append(fl.childLists, child)
	return child
}

// IndexFor returns the index of the given entry (compared by pointer
// identity). Indices start at 1 for the first entry of the root list, and one
// index is skipped in front of every child list. The root list and all child
// lists are sorted before the lookup. An error is returned when the entry is
// in none of the lists.
func (fl *FileList) IndexFor(file *File) (int, error) {
	// make sure every list is in its canonical order before counting
	fl.Sort()
	for _, child := range fl.childLists {
		child.Sort()
	}

	// Walk the root list followed by the child lists. Bumping the position
	// once before each list makes the root start at 1 and leaves a gap of one
	// index in front of each child list.
	lists := append([]*FileList{fl}, fl.childLists...)
	pos := 0
	for _, list := range lists {
		pos++
		for _, candidate := range list.files {
			if candidate == file {
				return pos, nil
			}
			pos++
		}
	}

	return -1, fmt.Errorf("file not found: %s", file.Name)
}

// Sort sorts the entries of the file list in place using the native
// f_name_cmp comparator, with the entry named "." always ordered first.
// Child lists are not touched.
func (fl *FileList) Sort() {
	// the rsync client sorts the file list by name. This is relevant because we want our file lists
	// to be synchronized with the server.
	sort.Slice(fl.files, func(i, j int) bool {

		// this is always the smallest element in the list
		if fl.files[i].Name == "." {
			return true
		}

		if fl.files[j].Name == "." {
			return false
		}

		// The C strings are allocated per comparison and released by the
		// deferred frees when this comparator call returns.
		dirname1 := C.CString(filepath.Dir(fl.files[i].Name))
		dirname2 := C.CString(filepath.Dir(fl.files[j].Name))
		basename1 := C.CString(filepath.Base(fl.files[i].Name))
		basename2 := C.CString(filepath.Base(fl.files[j].Name))
		defer C.free(unsafe.Pointer(dirname1))
		defer C.free(unsafe.Pointer(dirname2))
		defer C.free(unsafe.Pointer(basename1))
		defer C.free(unsafe.Pointer(basename2))

		// The last argument is the protocol version, hard-coded to 31 here.
		// NOTE(review): this assumes f_name_cmp returns exactly -1 for "less
		// than"; if it returns an arbitrary negative difference this should be
		// `< 0` — confirm against native_funcs.h.
		res := C.f_name_cmp(dirname1, dirname2, basename1, basename2, C.int(fl.files[i].Mode), C.int(fl.files[j].Mode), 31)
		return int(res) == -1
	})
}

// Send sorts the file list and writes every entry, followed by the
// end-of-list marker, to srv. It stops at and returns the first write error.
//
// Each entry is encoded as: xflags byte, name length (varint), name, three
// zero bytes for the file length, the mode as a raw 4-byte integer and, for
// symlinks only, the target length (varint) followed by the target.
func (fl *FileList) Send(srv *RsyncClient) error {
	// the receiving side expects the list in sorted order
	fl.Sort()

	for _, f := range fl.files {
		// xmitLongName: the name length is sent as a varint.
		// xmitSameTime: no modification time is sent.
		if err := srv.WriteByte(xmitLongName | xmitSameTime); err != nil {
			return err
		}

		// send the file name length followed by the name itself (no
		// terminating null byte is sent or counted)
		if err := srv.WriteVarInt(len(f.Name)); err != nil {
			return err
		}
		if err := srv.WriteSbuf(f.Name); err != nil {
			return err
		}

		// send the file length. We don't support varlong30 yet so just send
		// three zero bytes, one write each as before.
		for i := 0; i < 3; i++ {
			if err := srv.WriteByte(0); err != nil {
				return err
			}
		}

		// we don't need to send the modtime because we set xmitSameTime above

		// send the file mode
		if err := srv.WriteRawInt(int(f.Mode)); err != nil {
			return err
		}

		// TODO: Figure out if the client wants to keep atimes etc

		if util.IsSymlink(f.Mode) {
			// symlinks additionally carry the target length and the target
			if err := srv.WriteVarInt(len(f.SymlinkTarget)); err != nil {
				return err
			}
			if err := srv.WriteSbuf(f.SymlinkTarget); err != nil {
				return err
			}
		}
	}

	// a single 0 byte (empty xflags) marks the end of the file list
	return srv.WriteByte(0)
}

// Values for DeflateState.recvState, the state machine driven by
// ReceiveDeflateToken.
const (
	rInit = iota
	rIdle
	rRunning
	rInflating
	rInflated
)

// deflatedData is the value of the two top bits (flag & 0xC0) of a token flag
// byte that announces a block of deflated data.
const deflatedData = 0x40

// tokenRel is a token flag bit.
// NOTE(review): presumably marks a token number sent relative to the previous
// one, as in rsync's token.c — its use is not visible here, confirm.
const tokenRel = 0x80

// DeflateState mirrors global variables that rsync uses while receiving
// compressed tokens. It is owned by RsyncClient and updated by
// ReceiveDeflateToken.
type DeflateState struct {
	// recvState is the current state of the receive state machine (rInit, rIdle, ...).
	recvState int
	// isInitialized is not referenced in the visible code.
	// NOTE(review): presumably tracks whether the inflate stream was set up — confirm.
	isInitialized bool
	// rxToken is reset to 0 when the state machine leaves rInit.
	rxToken int
	// savedFlag holds a flag byte that is consumed (and cleared) before a new
	// flag byte is read from the connection.
	savedFlag int
}

// RsyncClient wraps a connection speaking the rsync daemon protocol together
// with the protocol state that rsync itself keeps in global variables. It is
// used for both directions: NewRsyncClient dials a daemon, NewRsyncServer
// wraps an already accepted connection.
type RsyncClient struct {
	// conn is the underlying transport.
	conn net.Conn
	// protocolVersion selects between the pre-30 and the varint based encodings.
	protocolVersion int
	// Module is the rsync module requested from the daemon.
	Module string
	// digest is the name of the checksum algorithm sent during negotiation.
	digest string
	// digestLen is the number of checksum bytes read per regular file in
	// ReadFileList.
	digestLen int

	// keeps track of the current state of the outbound connection: when true,
	// Write wraps all data in MSG_DATA packets
	outMultiplexed bool

	// keeps track of the current state of the inbound connection: when true,
	// Read expects tagged packets and extracts the MSG_DATA payloads
	inMultiplexed bool

	// the incoming buffer to store MSG_DATA payload that was received but not
	// yet consumed by a Read call
	inputBuffer []byte

	// prevPositive and prevNegative are global variables in rsync. It's hacky but
	// we have no choice but to set them. They hold the last index written
	// (outbound) or read (inbound) by WriteNdx/ReadNdx, which encode indices
	// as a difference to the previous one.
	prevPositiveOutbound int
	prevNegativeOutbound int
	prevPositiveInbound  int
	prevNegativeInbound  int

	// more state that represents a global variable in rsync: the number of
	// bytes of the current token that ReceiveToken still has to read
	residue int

	// flateState is the state used by ReceiveDeflateToken.
	flateState DeflateState
}

// shiftTag converts a message tag into its on-wire header form: MplexBase is
// added and the sum is moved into the bits above TagShift, leaving the lower
// bits free for the payload length.
func shiftTag(tag int) int {
	biased := tag + MplexBase
	return biased << TagShift
}

// unshiftTag splits a packet header into the message tag (with MplexBase
// removed) and the payload length stored in the lower 24 bits.
func unshiftTag(tag int) (int, int) {
	payloadLen := tag & 0xFFFFFF
	msgTag := (tag >> TagShift) - MplexBase
	return msgTag, payloadLen
}

// Read fills b completely with bytes from the connection.
//
// When the inbound side is not multiplexed the bytes are taken directly from
// the connection. When it is multiplexed, packets are read until enough
// MSG_DATA payload is available: MSG_DATA payloads are appended to the input
// buffer, packets with any other tag are consumed and discarded, and empty
// packets (e.g. keep-alives) are skipped. Payload that is not needed for this
// call stays buffered for the next one.
//
// An error is returned when the connection fails before b could be filled.
func (r *RsyncClient) Read(b []byte) error {
	// readFull reads from the raw connection until buf is full. A single
	// conn.Read may legitimately return fewer bytes than requested, so it has
	// to be called in a loop.
	readFull := func(buf []byte) error {
		for filled := 0; filled < len(buf); {
			n, err := r.conn.Read(buf[filled:])
			filled += n
			if err != nil && filled < len(buf) {
				return err
			}
		}
		return nil
	}

	if !r.inMultiplexed {
		return readFull(b)
	}

	// Pull in packets until the buffered payload can satisfy the request.
	// The packet header is read straight from the connection, so the
	// multiplex flag never has to be toggled and cannot be left in the wrong
	// state on an error path.
	for len(r.inputBuffer) < len(b) {
		var header [4]byte
		if err := readFull(header[:]); err != nil {
			return err
		}
		tag, msgBytes := unshiftTag(int(binary.LittleEndian.Uint32(header[:])))

		// nothing to consume, this can be the case for keep_alive messages etc
		if msgBytes == 0 {
			continue
		}

		payload := make([]byte, msgBytes)
		if err := readFull(payload); err != nil {
			return err
		}

		// Only MSG_DATA carries stream data; every other payload is dropped.
		// TODO: There might be special cases where other tags need handling!
		if tag == MsgData {
			r.inputBuffer = append(r.inputBuffer, payload...)
		}
	}

	// hand out the requested bytes and keep the remainder buffered
	copy(b, r.inputBuffer)
	r.inputBuffer = r.inputBuffer[len(b):]
	return nil
}

// MsgDelete is the multiplex tag sent by SendDeleteMessage.
// NOTE(review): 101 corresponds to MSG_DELETED in rsync's msgcode enum — confirm.
const MsgDelete = 101

// MsgExit is the multiplex tag sent by SendExitMessage.
// NOTE(review): 86 corresponds to MSG_ERROR_EXIT in rsync's msgcode enum — confirm.
const MsgExit = 86

// SendExitMessage sends a MsgExit packet without payload to the peer to
// signal that it should exit.
// The peer side code is here:
// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/io.c#L1548
func (r *RsyncClient) SendExitMessage() error {
	// Multiplexing is switched off while the header is written so that the
	// packet is not wrapped in a MSG_DATA packet; the previous state is put
	// back when the function returns.
	previous := r.outMultiplexed
	r.outMultiplexed = false
	defer func() { r.outMultiplexed = previous }()

	return r.WriteRawInt(shiftTag(MsgExit))
}

// SendDeleteMessage sends a MsgDelete packet carrying a 512 byte payload of
// 'A' characters to the peer.
// The peer side code is here:
// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/io.c#L1548
//
// The first write error is returned. The outbound multiplex state is restored
// on every return path, including the error ones.
func (r *RsyncClient) SendDeleteMessage() error {
	// Multiplexing is switched off while the packet is written so that it is
	// not wrapped in a MSG_DATA packet.
	saved := r.outMultiplexed
	r.outMultiplexed = false
	defer func() { r.outMultiplexed = saved }()

	payload := bytes.Repeat([]byte{'A'}, 512)

	// the header combines the shifted tag with the payload length
	tag := shiftTag(MsgDelete) | len(payload)
	if err := r.WriteRawInt(tag); err != nil {
		return err
	}
	return r.Write(payload)
}

// ReadLine consumes bytes from the connection up to and including the next
// newline or NUL byte and returns them without the terminator. The bytes are
// returned unmodified; no character set conversion takes place.
// If reading fails, an empty string and the error are returned.
func (r *RsyncClient) ReadLine() (string, error) {
	var line strings.Builder

	// read one byte at a time until we find a terminator, so that nothing
	// beyond the end of the line is consumed
	buf := make([]byte, 1)
	for {
		if err := r.Read(buf); err != nil {
			return "", err
		}
		if buf[0] == '\n' || buf[0] == '\x00' {
			break
		}

		// WriteByte keeps the raw byte. Converting the byte with string()
		// would treat it as a code point and re-encode values >= 0x80 as two
		// UTF-8 bytes.
		line.WriteByte(buf[0])
	}

	return line.String(), nil
}

// WriteLineOld sends line followed by a newline character, the line format
// used during the daemon greeting.
func (r *RsyncClient) WriteLineOld(line string) error {
	payload := append([]byte(line), '\n')
	return r.Write(payload)
}

// WriteLine sends line followed by a terminating NUL byte (not a newline;
// see WriteLineOld for the newline terminated variant).
func (r *RsyncClient) WriteLine(line string) error {
	payload := append([]byte(line), 0)
	return r.Write(payload)
}

// Write sends data to the peer. If the outbound side is multiplexed the data
// is wrapped in a single MSG_DATA packet, otherwise it is written to the
// connection as is.
func (r *RsyncClient) Write(data []byte) error {
	if r.outMultiplexed {
		// WriteMsgData emits its header through WriteRawInt, which calls
		// Write again. Multiplexing is turned off for the duration of the
		// call so that the header is not wrapped recursively.
		r.outMultiplexed = false
		defer func() { r.outMultiplexed = true }()
		return r.WriteMsgData(data)
	}

	_, err := r.conn.Write(data)
	return err
}

// WriteSbuf sends the bytes of s to the connection without any length prefix
// or terminator.
func (r *RsyncClient) WriteSbuf(s string) error {
	raw := []byte(s)
	return r.Write(raw)
}

// EnableMultiPlexOutbound switches the outbound direction to multiplexed
// mode: from now on Write wraps all data in MSG_DATA packets.
func (r *RsyncClient) EnableMultiPlexOutbound() {
	r.outMultiplexed = true
}

// EnableMultiPlexInbound switches the inbound direction to multiplexed mode:
// from now on Read expects tagged packets on the connection.
func (r *RsyncClient) EnableMultiPlexInbound() {
	r.inMultiplexed = true
}

// WriteByte sends the single byte b to the connection.
func (r *RsyncClient) WriteByte(b byte) error {
	one := [1]byte{b}
	return r.Write(one[:])
}

// ReadByte reads exactly one byte from the connection. On failure it returns
// 0 and the error.
func (r *RsyncClient) ReadByte() (byte, error) {
	var one [1]byte
	if err := r.Read(one[:]); err != nil {
		return 0, err
	}
	return one[0], nil
}

// WriteRawInt sends the lower 32 bits of i as a 4-byte little-endian integer.
// The byte order is fixed and does not depend on the machine.
func (r *RsyncClient) WriteRawInt(i int) error {
	var raw [4]byte
	binary.LittleEndian.PutUint32(raw[:], uint32(i))
	return r.Write(raw[:])
}

// WriteRawInt64 sends i as an 8-byte little-endian integer. The byte order is
// fixed and does not depend on the machine.
// This is used for the digest algorithm.
func (r *RsyncClient) WriteRawInt64(i int) error {
	var raw [8]byte
	binary.LittleEndian.PutUint64(raw[:], uint64(i))
	return r.Write(raw[:])
}

// WriteShortInt sends the lower 16 bits of i as a 2-byte little-endian
// integer.
func (r *RsyncClient) WriteShortInt(i int) error {
	var raw [2]byte
	binary.LittleEndian.PutUint16(raw[:], uint16(i))
	return r.Write(raw[:])
}

// ReadInt reads a 4-byte little-endian integer from the connection. The value
// is interpreted as unsigned before it is converted to int, so callers that
// expect negative numbers have to reinterpret it as int32 themselves.
func (r *RsyncClient) ReadInt() (int, error) {
	var raw [4]byte
	if err := r.Read(raw[:]); err != nil {
		return 0, err
	}
	return int(binary.LittleEndian.Uint32(raw[:])), nil
}

// ReadShortInt reads a 2-byte little-endian unsigned integer from the
// connection.
func (r *RsyncClient) ReadShortInt() (int, error) {
	var raw [2]byte
	if err := r.Read(raw[:]); err != nil {
		return 0, err
	}
	return int(binary.LittleEndian.Uint16(raw[:])), nil
}

// ReadVarLong reads a variable length 64-bit integer of which at least
// minBytes bytes are always present on the wire.
// Copies the logic from https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/io.c#L1826
//
// The first byte received selects, via IntByteExtra, how many extra bytes
// follow the initial minBytes bytes; the remaining minBytes-1 bytes are the
// low-order bytes of the value. minBytes must be between 1 and 8.
//
// An error is returned when reading fails, when minBytes is out of range, or
// when the peer announces more extra bytes than fit into the value. The
// reference implementation aborts in the latter case; previously this
// function indexed past its buffer and panicked.
// TODO: the extra-bytes path is untested against a real server.
func (r *RsyncClient) ReadVarLong(minBytes int) (int64, error) {
	if minBytes < 1 || minBytes > 8 {
		return 0, fmt.Errorf("invalid minimum byte count for varlong: %d", minBytes)
	}

	// value mirrors the 9 byte union `u.b` of the reference implementation;
	// only its first 8 bytes end up in the result
	var value [9]byte

	head := make([]byte, minBytes)
	if err := r.Read(head); err != nil {
		return 0, err
	}

	// `memcpy(u.b, b2+1, min_bytes-1);`
	copy(value[:], head[1:])

	extra := IntByteExtra[head[0]/4]
	if extra == 0 {
		// no extra bytes: the first byte is the most significant one
		value[minBytes-1] = head[0]
		return int64(binary.LittleEndian.Uint64(value[:])), nil
	}

	// never let a peer-controlled length run past the buffer
	if minBytes+extra > len(value) {
		return 0, fmt.Errorf("overflow in ReadVarLong: %d extra bytes with minBytes %d", extra, minBytes)
	}

	// A little reminder for future me
	debug.PrintStack()
	fmt.Printf("[*] WARNING: Server sent %d extra bytes in ReadVarLong. I'm not confident about this code working. If there are any issues, doublecheck here\n", extra)

	// the extra bytes continue the value right after the first minBytes-1 bytes
	if err := r.Read(value[minBytes-1 : minBytes-1+extra]); err != nil {
		return 0, err
	}

	// the bits of the first byte below the length marker are the most
	// significant bits of the value
	bit := 1 << (8 - extra)
	value[minBytes+extra-1] = head[0] & (byte(bit) - 1)

	return int64(binary.LittleEndian.Uint64(value[:])), nil
}

// ReadVarLong30 reads a 64-bit integer in the encoding of the negotiated
// protocol version: the variable length encoding (with at least minBytes
// bytes) for protocol 30 and newer, the fixed "long int" encoding before that.
func (r *RsyncClient) ReadVarLong30(minBytes int) (int64, error) {
	if r.protocolVersion >= 30 {
		return r.ReadVarLong(minBytes)
	}
	return r.ReadLongInt()
}

// ReadVarint30 reads a 32-bit integer in the encoding of the negotiated
// protocol version: a varint for protocol 30 and newer, a raw 4-byte integer
// before that.
func (r *RsyncClient) ReadVarint30() (int, error) {
	if r.protocolVersion >= 30 {
		return r.ReadVarInt()
	}
	return r.ReadInt()
}

// WriteMsgData sends data as one MSG_DATA packet: a 4-byte header holding the
// shifted tag and the payload length, followed by the payload written
// directly to the connection. Payloads longer than 0xFFFFFF bytes do not fit
// into the header and are rejected with an error.
// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/io.c#L1485-L1488
// ^ This is the code that reads the message data length from the tag
func (r *RsyncClient) WriteMsgData(data []byte) error {
	size := len(data)
	if size > 0xFFFFFF {
		return fmt.Errorf("data is too long: %d", size)
	}

	// tag and length travel together in a single integer
	header := shiftTag(MsgData) | size
	if err := r.WriteRawInt(header); err != nil {
		return err
	}

	// now flush the actual data
	_, err := r.conn.Write(data)
	return err
}

// EncodeVString returns s in vstring encoding: a length prefix followed by the
// bytes of s. Lengths up to 0x7F use a one byte prefix; longer strings use two
// bytes, the first being the high byte of the length with the top bit set.
// Strings longer than 0x7FFF bytes cannot be encoded and yield an error.
// The logic follows
// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/io.c#L2221
func EncodeVString(s string) ([]byte, error) {
	n := len(s)
	if n > 0x7FFF {
		return nil, fmt.Errorf("vstring (%s) is too long: %d", s, n)
	}

	var encoded bytes.Buffer
	encoded.Grow(n + 2)

	// the optional high byte comes first and carries the marker bit
	if n > 0x7f {
		encoded.WriteByte(byte(n/0x100 + 0x80))
	}
	encoded.WriteByte(byte(n))
	encoded.WriteString(s)

	return encoded.Bytes(), nil
}

// WriteNdx sends a file index using rsync's byte-reduced encoding.
// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/io.c#L2242
//
// Before protocol 30 the index is sent as a raw 4-byte integer. Afterwards:
//   - NdxDone is a single 0 byte and leaves the state untouched,
//   - negative indices are prefixed with 0xFF and sent as their absolute value,
//   - the value is sent as the difference to the previously sent index of the
//     same sign: one byte for 1..253, 0xFE plus two bytes for 0 and
//     254..32767, and 0xFE plus all four bytes of the index (top bit set)
//     otherwise.
func (r *RsyncClient) WriteNdx(ndx int) error {
	// TODO: This is also true if read_batch is set to true
	if r.protocolVersion < 30 {
		return r.WriteRawInt(ndx)
	}
	if ndx == NdxDone {
		return r.WriteByte(0)
	}

	encoded := make([]byte, 0, 6)
	var diff int
	if ndx >= 0 {
		diff = ndx - r.prevPositiveOutbound
		r.prevPositiveOutbound = ndx
	} else {
		encoded = append(encoded, 0xff)
		ndx = -ndx
		diff = ndx - r.prevNegativeOutbound
		r.prevNegativeOutbound = ndx
	}

	switch {
	case diff > 0 && diff < 0xFE:
		// small forward step: the difference fits into one byte
		encoded = append(encoded, byte(diff))
	case diff < 0 || diff > 0x7FFF:
		// the difference is not representable: send the full index
		encoded = append(encoded,
			0xFE,
			byte(ndx>>24)|0x80,
			byte(ndx),
			byte(ndx>>8),
			byte(ndx>>16),
		)
	default:
		// two byte difference, high byte first
		encoded = append(encoded, 0xFE, byte(diff>>8), byte(diff))
	}

	return r.Write(encoded)
}

// Selectors for the inbound "previous index" state. rsync uses a pointer to
// one of two static variables; here the selector is passed to
// updatePrevInbound and getPrevInbound instead.
const (
	prevNegativePtr = iota
	prevPositivePtr
)

// updatePrevInbound stores value as the last inbound index of the kind
// selected by prevPtr. Any selector other than prevPositivePtr addresses the
// negative state. rsync uses a pointer to update the previous value:
// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/io.c#L2292
func (r *RsyncClient) updatePrevInbound(prevPtr int, value int) {
	switch prevPtr {
	case prevPositivePtr:
		r.prevPositiveInbound = value
	default:
		r.prevNegativeInbound = value
	}
}

// getPrevInbound returns the last inbound index of the kind selected by
// prevPtr. Any selector other than prevPositivePtr addresses the negative
// state.
func (r *RsyncClient) getPrevInbound(prevPtr int) int {
	switch prevPtr {
	case prevPositivePtr:
		return r.prevPositiveInbound
	default:
		return r.prevNegativeInbound
	}
}

// ReadNdx reads a file index in rsync's byte-reduced encoding, the
// counterpart of WriteNdx.
// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/io.c#L2289
//
// Before protocol 30 the index is a raw 4-byte integer, which is interpreted
// as signed so that negative values such as NdxDone are returned as such.
// Afterwards a leading 0xFF marks a negative index, a single 0 byte is
// NdxDone, and the value is either a difference to the previously read index
// of the same sign (one byte, or 0xFE plus two bytes) or the full index (0xFE
// plus four bytes with the top bit of the first byte set).
func (r *RsyncClient) ReadNdx() (int, error) {
	if r.protocolVersion < 30 {
		raw, err := r.ReadInt()
		if err != nil {
			return 0, err
		}
		// ReadInt returns the unsigned value; the index is a signed 32-bit number
		return int(int32(raw)), nil
	}

	b, err := r.ReadByte()
	if err != nil {
		return 0, err
	}

	var prevPtr int
	switch b {
	case 0xFF:
		// negative index: the actual encoding starts with the next byte
		b, err = r.ReadByte()
		if err != nil {
			return 0, err
		}
		prevPtr = prevNegativePtr
	case 0:
		return NdxDone, nil
	default:
		prevPtr = prevPositivePtr
	}

	num := 0
	if b == 0xFE {
		buf := make([]byte, 4)
		// read 2 more bytes
		for i := 0; i < 2; i++ {
			b, err = r.ReadByte()
			if err != nil {
				return 0, err
			}
			buf[i] = b
		}

		if buf[0]&0x80 != 0 {
			// Full index: the first byte is the most significant byte with a
			// marker bit that has to be cleared (`b[3] = b[0] & ~0x80` in C),
			// the second byte is the least significant one.
			buf[3] = buf[0] &^ 0x80
			buf[0] = buf[1]
			// read the 2 middle bytes
			for i := 0; i < 2; i++ {
				b, err = r.ReadByte()
				if err != nil {
					return 0, err
				}
				buf[i+1] = b
			}
			num = int(binary.LittleEndian.Uint32(buf))
		} else {
			// two byte difference, high byte first
			num = (int(buf[0]) << 8) + int(buf[1]) + r.getPrevInbound(prevPtr)
		}
	} else {
		// one byte difference
		num = int(b) + r.getPrevInbound(prevPtr)
	}

	// the state always holds the absolute value
	r.updatePrevInbound(prevPtr, num)
	if prevPtr == prevNegativePtr {
		num = -num
	}

	return num, nil
}

// WriteVString sends s in vstring encoding (see EncodeVString). Strings that
// cannot be encoded are reported as an error and nothing is sent.
func (r *RsyncClient) WriteVString(s string) error {
	encoded, err := EncodeVString(s)
	if err == nil {
		err = r.Write(encoded)
	}
	return err
}

// ReadVarInt reads a variable length 32-bit integer from the connection. The
// first byte selects, via IntByteExtra, how many extra bytes follow; the
// remaining bits of the first byte are the most significant bits of the
// value. The value is returned as an unsigned quantity converted to int.
// An error is returned when reading fails or when more than 4 extra bytes are
// announced.
// The logic follows
// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/io.c#L1794
func (r *RsyncClient) ReadVarInt() (int, error) {
	first, err := r.ReadByte()
	if err != nil {
		return 0, err
	}

	// raw mirrors the 5 byte union of the reference implementation; only the
	// first 4 bytes end up in the result
	var raw [5]byte

	extra := IntByteExtra[first/4]
	if extra == 0 {
		raw[0] = first
		return int(binary.LittleEndian.Uint32(raw[:])), nil
	}

	if extra >= len(raw) {
		return 0, fmt.Errorf("extra bytes (%d) are too many", extra)
	}

	// the extra bytes are the low-order bytes of the value
	for i := range raw[:extra] {
		next, err := r.ReadByte()
		if err != nil {
			return 0, err
		}
		raw[i] = next
	}

	// strip the length marker from the first byte
	mask := uint8(1<<(8-extra)) - 1
	raw[extra] = first & mask

	return int(binary.LittleEndian.Uint32(raw[:])), nil
}

// WriteVarInt sends the lower 32 bits of num as a variable length integer,
// the counterpart of ReadVarInt, and returns the error of the underlying
// write. The logic follows write_varint in rsync's io.c:
// https://github.com/RsyncProject/rsync/blob/master/io.c#L2088
//
// buf[1..4] hold the little-endian value and buf[0] the leading byte that is
// computed below. Trailing zero bytes are dropped; the leading byte carries
// the length marker and, if they fit, the most significant bits of the value.
func (r *RsyncClient) WriteVarInt(num int) error {
	buf := []uint8{0}
	buf = binary.LittleEndian.AppendUint32(buf, uint32(num))
	var bit byte
	cnt := 0

	// find the most significant non-zero byte (at least one byte is kept)
	for cnt = 4; cnt > 1 && buf[cnt] == 0; cnt-- {
	}
	bit = (byte)(1 << (7 - cnt + 1))
	if buf[cnt] >= bit {
		// the top byte does not fit next to the marker: send one more byte
		// and use the leading byte for the marker only
		cnt++
		buf[0] = ^(bit - 1)
	} else if cnt > 1 {
		// fold the top byte into the leading byte together with the marker
		buf[0] = buf[cnt] | ^(bit*2 - 1)
	} else {
		// single byte value below 0x80: sent as is
		buf[0] = buf[1]
	}

	return r.Write(buf[:cnt])
}

// ReadLongInt reads a 64-bit integer in the pre-protocol-30 encoding: a
// 4-byte integer, or, if that integer is 0xffffffff, an 8-byte little-endian
// integer that follows it.
func (r *RsyncClient) ReadLongInt() (int64, error) {
	short, err := r.ReadInt()
	if err != nil {
		return 0, err
	}

	if short == 0xffffffff {
		// escape value: the real number follows as 8 bytes
		var wide [8]byte
		err = r.Read(wide[:])
		return int64(binary.LittleEndian.Uint64(wide[:])), err
	}

	// the easy case
	return int64(short), nil
}

// ReadVString reads a string in vstring encoding (see EncodeVString): a one
// or two byte length prefix followed by that many bytes.
func (r *RsyncClient) ReadVString() (string, error) {
	first, err := r.ReadByte()
	if err != nil {
		return "", err
	}

	size := int(first)
	if size&0x80 != 0 {
		// the top bit marks a two byte length, high byte first
		second, err := r.ReadByte()
		if err != nil {
			return "", err
		}
		size = (size&^0x80)*0x100 + int(second)
	}

	// read the string
	content := make([]byte, size)
	if err := r.Read(content); err != nil {
		return "", err
	}
	return string(content), nil
}

// parseHostString splits a daemon URL of the form rsync://host:port/module
// into host, port and module. All three parts are mandatory; the module must
// not contain a slash and the host must not contain a colon. Any other input,
// including strings shorter than the scheme prefix, yields an error.
func parseHostString(hostString string) (string, string, string, error) {
	const scheme = "rsync://"

	// HasPrefix is safe for inputs shorter than the prefix, slicing is not
	if !strings.HasPrefix(hostString, scheme) {
		return "", "", "", fmt.Errorf("invalid host string: %s. Expected format: rsync://host:port/module", hostString)
	}

	hostString = hostString[len(scheme):]
	addrParts := strings.Split(hostString, "/")
	if len(addrParts) != 2 {
		return "", "", "", fmt.Errorf("invalid host string: %s. Expected format: rsync://host:port/module", hostString)
	}

	// this one is easy
	module := addrParts[1]

	// now split the host:port again
	addrParts = strings.Split(addrParts[0], ":")
	if len(addrParts) != 2 {
		return "", "", "", fmt.Errorf("invalid host string: %s. Expected format: rsync://host:port/module", hostString)
	}

	host := addrParts[0]
	port := addrParts[1]

	return host, port, module, nil
}

// getDigestLen returns the checksum length in bytes for the named digest
// algorithm. The name is matched case-insensitively; "xxh64" and "sha1" are
// known. Unknown names yield -1 and an error.
func getDigestLen(digest string) (int, error) {
	switch strings.ToLower(digest) {
	case "xxh64":
		return 8, nil
	case "sha1":
		return 20, nil
	default:
		return -1, fmt.Errorf("Unknown digest type")
	}
}

// NewRsyncServer wraps an already established connection in an RsyncClient
// that plays the server role. The protocol version is fixed to 31 and the
// fields that only matter for the client role (module, digest) are left empty.
//
// The index state is seeded like the static variables of rsync's
// read_ndx/write_ndx: -1 for the previous positive index and 1 for the
// previous negative index, in both directions.
func NewRsyncServer(conn net.Conn) *RsyncClient {
	return &RsyncClient{
		conn:                 conn,
		protocolVersion:      31,
		Module:               "",
		digest:               "",
		digestLen:            0,
		prevPositiveOutbound: -1,
		prevNegativeOutbound: 1,
		prevPositiveInbound:  -1,
		prevNegativeInbound:  1,
	}
}

// NewRsyncClient connects to the rsync daemon described by hoststring
// (rsync://host:port/module) and performs the daemon greeting: it announces
// the given protocol version and digest, requests the module and waits for
// "@RSYNCD: OK".
//
// An error is returned when an argument is invalid, the connection cannot be
// established, or the daemon answers with an error or requires
// authentication (not implemented). All arguments are validated before the
// daemon is contacted, and the connection is closed again whenever the
// greeting fails.
func NewRsyncClient(hoststring string, protocolVersion string, digest string) (*RsyncClient, error) {
	protocolVersionInt, err := strconv.Atoi(protocolVersion)
	if err != nil {
		return nil, fmt.Errorf("invalid protocol version: %s: %w", protocolVersion, err)
	}

	host, port, module, err := parseHostString(hoststring)
	if err != nil {
		return nil, err
	}

	digestLen, err := getDigestLen(digest)
	if err != nil {
		return nil, err
	}

	conn, err := net.Dial("tcp", host+":"+port)
	if err != nil {
		return nil, err
	}

	// from here on every failure has to release the connection
	established := false
	defer func() {
		if !established {
			// best effort: the handshake error is the one worth reporting
			_ = conn.Close()
		}
	}()

	client := &RsyncClient{conn: conn, protocolVersion: protocolVersionInt,
		Module: module, digest: digest, digestLen: digestLen}

	// the state regarding the global variables is reset here; rsync seeds the
	// previous positive index with -1 and the previous negative index with 1
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/io.c#L2244
	client.prevPositiveOutbound = -1
	client.prevNegativeOutbound = 1
	client.prevPositiveInbound = -1
	client.prevNegativeInbound = 1

	// read the greetings - we don't care about the content
	if _, err := client.ReadLine(); err != nil {
		return nil, err
	}

	// send the header
	if err := client.WriteLineOld(fmt.Sprintf("@RSYNCD: %d.%d %s", protocolVersionInt, 0, digest)); err != nil {
		return nil, err
	}

	// send the module
	if err := client.WriteLineOld(module); err != nil {
		return nil, err
	}

	// now consume lines until we read @RSYNCD: OK
	for {
		line, err := client.ReadLine()
		if err != nil {
			return nil, err
		}

		// there might be lines that do not match rsyncd: OK. In that case it's the message of the
		// day that we don't care about.
		if line == "@RSYNCD: OK" {
			break
		}

		if strings.Contains(line, "@ERROR") {
			return nil, fmt.Errorf("error while connecting to rsync daemon: %s", line)
		}

		if strings.Contains(line, "@RSYNCD: AUTHREQD") {
			return nil, fmt.Errorf("authentication required. Not implemented yet")
		}
	}

	// the daemon accepted the module, the connection now belongs to the caller
	established = true
	return client, nil
}

// Close shuts down the underlying connection and returns the error reported
// by it, if any.
func (r *RsyncClient) Close() error {
	err := r.conn.Close()
	return err
}

// SetupProtocol exchanges protocol information with the rsync daemon after the connection has been
// established.
// There might be different reads and writes depending on the protocol version and options
// that were sent.
// The code is here:
// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/compat.c#L572
//
// It consumes the compat flags and the digest list offered by the server,
// answers with the configured digest, consumes the checksum seed and finally
// switches both directions to multiplexed mode. The first read or write error
// is returned, in which case multiplexing is not enabled.
func (r *RsyncClient) SetupProtocol() error {
	// The server will send us the compat_flags as a varint. We don't care about the value,
	// but consume it anyway.
	if _, err := r.ReadVarInt(); err != nil {
		return err
	}

	// At this point we will be in digest negotiation. The server sends the
	// digests it supports, which we consume without looking at them.
	// TODO: We can parse this string to automatically determine MAX_DIGEST_LEN of the server
	if _, err := r.ReadVString(); err != nil {
		return err
	}

	// Answer with the one digest we were configured with to force its use
	if err := r.WriteVString(r.digest); err != nil {
		return err
	}

	// we will now receive the checksum seed. We don't need it, so we just consume it
	if _, err := r.ReadInt(); err != nil {
		return err
	}

	// from here on out, the protocol is multiplexed.
	r.inMultiplexed = true
	r.outMultiplexed = true

	// last but not least, the server expects us to send filter lists:
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/exclude.c#L1681
	// This looks like a good place for heap grooming
	// TODO: There may have been an additional WriteRawInt(0) here which seems to have caused
	// a bug, so nothing is sent here.
	// return r.WriteRawInt(0)
	return nil
}

// ReadFileList reads the file list sent by the server and returns its regular
// files and directories, sorted with the native f_name_cmp comparator so that
// the order matches the one the server uses. Entries of other types are
// parsed but dropped.
//
// An error is returned when reading fails, when the server reports an error
// code at the end of the list, when an entry is malformed, or when the list
// contains exactly one usable entry (either there are no files on the server
// or it sent a recursive file list, which is not supported).
func (r *RsyncClient) ReadFileList() ([]File, error) {

	// unfortunately, rsync does some things with global variables: the mode
	// of the previous entry is reused when xmitSameMode is set
	var prevMode uint32

	// it keeps track of the previous name in a global variable too; entries
	// may share a prefix with it
	lastname := ""

	result := []File{}
	// this parsing loop corresponds to
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/flist.c#L2601
	for {
		// because we set the `v` string as part of the client options
		// the flags are sent as a varint
		flags, err := r.ReadVarInt()
		if err != nil {
			return nil, err
		}

		// if flags is 0, we are done
		if flags == 0 {
			// in this mode, another varint is sent which contains a potential error
			errCode, err := r.ReadVarInt()
			if err != nil {
				return nil, err
			}
			if errCode != 0 {
				return nil, fmt.Errorf("error while reading file list: %d", errCode)
			}

			break
		}

		// we're now following `recv_file_entry`
		// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/flist.c#L682

		// l1 is the length of the prefix shared with the previous name, l2 the
		// length of the name suffix that follows on the wire
		var l1 int
		var l2 int
		if flags&xmitSameName != 0 {
			l1Tmp, err := r.ReadByte()
			if err != nil {
				return nil, err
			}
			l1 = int(l1Tmp)
		}

		if flags&xmitLongName != 0 {
			// read l2 value as a varint
			l2, err = r.ReadVarint30()
			if err != nil {
				return nil, err
			}
		} else {
			// read l2 value as a byte
			l2Tmp, err := r.ReadByte()
			if err != nil {
				return nil, err
			}
			l2 = int(l2Tmp)
		}

		// both lengths are controlled by the server; reject values that would
		// make the slicing or the allocation below panic
		if l1 > len(lastname) {
			return nil, fmt.Errorf("invalid file list entry: shared prefix length %d exceeds previous name length %d", l1, len(lastname))
		}
		if l2 < 0 {
			return nil, fmt.Errorf("invalid file list entry: negative name length %d", l2)
		}

		nameBuf := make([]byte, l2)
		err = r.Read(nameBuf)
		if err != nil {
			return nil, err
		}
		name := lastname[:l1] + string(nameBuf)
		lastname = name

		// we don't care about hard links but if this is one, we need to consume all the bytes
		if r.protocolVersion >= 30 && bitsSetAndUnset(flags, xmitHLinked, xmitHLinkedFirst) {
			// now consume the hlink index, we don't care about this value
			_, err = r.ReadVarInt()
			if err != nil {
				return nil, err
			}
		}

		// read at least 3 bytes for the file length
		fileLength, err := r.ReadVarLong30(3)
		if err != nil {
			return nil, err
		}

		// next comes the modification time. We don't care, just consume it
		if (flags & xmitSameTime) == 0 {

			if r.protocolVersion >= 30 {
				_, err := r.ReadVarLong30(4)
				if err != nil {
					return nil, err
				}
			} else {
				_, err := r.ReadInt()
				if err != nil {
					return nil, err
				}
			}
		}

		if flags&xmitModNsec != 0 {
			// we don't care about nanoseconds, just consume the varint
			_, err := r.ReadVarInt()
			if err != nil {
				return nil, err
			}
		}

		mode := prevMode
		if flags&xmitSameMode == 0 {
			modeTmp, err := r.ReadInt()
			if err != nil {
				return nil, err
			}
			mode = uint32(modeTmp)
			prevMode = mode
		}

		file := &File{
			Name:   name,
			Size:   int(fileLength),
			Mode:   mode,
			Digest: nil,
		}

		// since we set --checksum we will receive the checksum of this file if it's a regular file
		if util.IsReg(file.Mode) {
			digestBuf := make([]byte, r.digestLen)
			err = r.Read(digestBuf)
			if err != nil {
				return nil, err
			}
			file.Digest = digestBuf
		}

		// we only care about files and directories
		if util.IsReg(file.Mode) || util.IsDir(file.Mode) {
			result = append(result, *file)
		}

	}

	// report the unsupported case to the caller instead of crashing the process
	if len(result) == 1 {
		return nil, fmt.Errorf("either there are no files on the server or the server just sent a recursive file list, for which we don't have support yet")
	}

	// the rsync client sorts the file list by name. This is relevant because we want our file lists
	// to be synchronized with the server.
	sort.Slice(result, func(i, j int) bool {

		// this is always the smallest element in the list
		if result[i].Name == "." {
			return true
		}

		if result[j].Name == "." {
			return false
		}

		// the C strings only live for the duration of this comparison
		dirname1 := C.CString(filepath.Dir(result[i].Name))
		dirname2 := C.CString(filepath.Dir(result[j].Name))
		basename1 := C.CString(filepath.Base(result[i].Name))
		basename2 := C.CString(filepath.Base(result[j].Name))
		defer C.free(unsafe.Pointer(dirname1))
		defer C.free(unsafe.Pointer(dirname2))
		defer C.free(unsafe.Pointer(basename1))
		defer C.free(unsafe.Pointer(basename2))

		res := C.f_name_cmp(dirname1, dirname2, basename1, basename2, C.int(result[i].Mode), C.int(result[j].Mode), C.int(r.protocolVersion))
		return int(res) == -1
	})

	return result, nil

}

// chunkSize is the upper bound, in bytes (32 KiB), on a single chunk of token data:
// ReceiveToken never returns more than this many bytes per call and SendToken never
// writes more than this many bytes per length-prefixed chunk.
const chunkSize = (32 * 1024)

// ReceiveToken reads the next token of an uncompressed transfer from the server.
// The logic mirrors the uncompressed path of
// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/token.c#L281
//
// The first return value tells the caller what was received:
//   - a positive value is the number of literal bytes returned in the slice
//     (at most chunkSize per call),
//   - zero means there is no more data,
//   - a negative value means the token matched a checksum.
//
// When an error is returned the other two values are 0 and nil.
func (r *RsyncClient) ReceiveToken() (int, []byte, error) {
	// Nothing is left over from a previous token, so the next int on the wire is
	// a new token header.
	if r.residue == 0 {
		header, err := r.ReadInt()
		if err != nil {
			return 0, nil, err
		}

		// Interpret the header as a signed 32 bit value: zero ends the data and
		// negative values signal a checksum match. Neither carries a payload.
		if int32(header) <= 0 {
			return header, nil, nil
		}
		r.residue = header
	}

	// Hand the pending literal data out in pieces of at most chunkSize bytes.
	size := chunkSize
	if r.residue < size {
		size = r.residue
	}
	r.residue -= size

	payload := make([]byte, size)
	if err := r.Read(payload); err != nil {
		return 0, nil, err
	}
	return size, payload, nil
}

// SendToken writes buf to the connection as uncompressed token data.
//
// The data is split into chunks of at most chunkSize bytes. Each chunk is
// preceded by its length written with WriteRawInt, and a final 0 is written to
// indicate that there is no more data. An empty buf results in only the
// terminating 0 being written.
//
// The first error reported by WriteRawInt is returned and nothing further is
// written after it.
func (r *RsyncClient) SendToken(buf []byte) error {
	for sent := 0; sent < len(buf); {
		n := min(chunkSize, len(buf)-sent)
		if err := r.WriteRawInt(n); err != nil {
			return err
		}

		// NOTE(review): the result of Write is not inspected because its signature
		// is not visible from this function; if it returns an error it should be
		// propagated here as well.
		r.Write(buf[sent : sent+n])
		sent += n
	}

	// send a 0 to indicate that there is no more data
	return r.WriteRawInt(0)
}

// ReceiveDeflateToken receives a single token from a flate stream
// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/token.c#L500
//
// The compressed payload is not inflated; the function only reports which kind
// of token arrived. The first return value is:
//   - 1 when a block of deflated data was read from the connection (the bytes
//     are consumed and discarded, the returned slice is nil),
//   - 0 when the end-of-data flag was read,
//   - -1 - rxToken when a token reference was read.
//
// An error is returned when reading from the connection fails, or when the
// receive state is one this function cannot continue from. That includes
// rInflating, which the function itself enters after reading deflated data.
// NOTE(review): a caller that wants to read further tokens after a deflated
// block presumably has to reset the receive state first; confirm against the
// callers.
func (r *RsyncClient) ReceiveDeflateToken() (int, []byte, error) {
	if r.flateState.recvState == rInit {
		r.flateState.recvState = rIdle
		r.flateState.rxToken = 0
	}

	// Only the idle and inflated states are handled. Report anything else to the
	// caller instead of looping forever without making progress.
	if r.flateState.recvState != rIdle && r.flateState.recvState != rInflated {
		return 0, nil, fmt.Errorf("unexpected deflate receive state %v", r.flateState.recvState)
	}

	// Use a previously saved flag byte if there is one, otherwise read the next
	// flag byte from the connection.
	var flag int
	if r.flateState.savedFlag != 0 {
		flag = r.flateState.savedFlag & 0xff
		r.flateState.savedFlag = 0
	} else {
		flagB, err := r.ReadByte()
		if err != nil {
			return 0, nil, err
		}
		flag = int(flagB)
	}

	if (flag & 0xC0) == deflatedData {
		// The length of the deflated block is 14 bits wide: the low 6 bits of the
		// flag are the high part, the following byte is the low part.
		flagLength := (flag & 0x3f) << 8
		blength, err := r.ReadByte()
		if err != nil {
			return 0, nil, err
		}
		flagLength += int(blength)

		// Consume the compressed bytes so the stream stays in sync. They are not
		// inflated.
		cbuf := make([]byte, flagLength)
		if err := r.Read(cbuf); err != nil {
			return 0, nil, err
		}
		r.flateState.recvState = rInflating

		// for now, we just want to get a signal from the stream to check if we got data or
		// matched a token
		return 1, nil, nil
	}

	// the end, no more data after this
	if flag == 0 {
		r.flateState.recvState = rInit
		return 0, nil, nil
	}

	if flag&tokenRel != 0 {
		// The token is encoded relative to the previous one in the low 6 bits.
		r.flateState.rxToken += flag & 0x3f
		// NOTE(review): the shifted flag is not used afterwards. The referenced
		// rsync code appears to inspect it to read a run count; that is not
		// handled here. Confirm whether run tokens can occur.
		flag >>= 6
	} else {
		rxToken, err := r.ReadInt()
		if err != nil {
			return 0, nil, err
		}
		r.flateState.rxToken = rxToken
	}

	return -1 - r.flateState.rxToken, nil, nil
}

// DownloadFile downloads a single file from the server.
//
// fileNdx is the index that is written to the server with WriteNdx, and file
// supplies the name, size and expected digest. The function requests the
// transfer, sends a checksum header with a block count of 0, consumes the
// values the server sends back, and then collects the literal data returned by
// ReceiveToken until it reports the end of the data.
//
// The returned slice holds the file contents. An error is returned if any read
// or write on the connection fails, or if the digest received after the data
// differs from file.Digest. The function panics if ReceiveToken reports a
// negative token.
func (r *RsyncClient) DownloadFile(fileNdx int, file *File) ([]byte, error) {
	// First, we write the ndx and the iflags:
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/sender.c#L230
	if err := r.WriteNdx(fileNdx); err != nil {
		return nil, err
	}

	// for the iflags, set the ITEM_TRANSFER bit
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/rsync.h#L111
	if err := r.WriteShortInt(ItemTransferFlag); err != nil {
		return nil, err
	}

	// now we are going to send checksums to the server. Since we don't have the files, there
	// are no block checksums to send, only the header
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/sender.c#L345
	sumHead := []int{
		// count is 0
		0,
		// blength is the size of the file
		file.Size,
		// s2 length
		r.digestLen,
		// remainder
		// NOTE(review): the value written here is 5 although the comment that
		// originally accompanied it said "remainder is 0"; confirm which one is
		// intended.
		5,
	}
	for _, v := range sumHead {
		// stop at the first failed write instead of continuing on a broken connection
		if err := r.WriteRawInt(v); err != nil {
			return nil, err
		}
	}

	// no more data to write since count is 0

	// for some reason, the server will now send us the ndx and iflags back again, although it
	// never modifies them. Let's consume the data
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/sender.c#L409-L410
	if _, err := r.ReadNdx(); err != nil {
		return nil, err
	}
	if _, err := r.ReadShortInt(); err != nil {
		return nil, err
	}

	// it will now send us the sum head back as well
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/sender.c#L409-L410
	//
	// That is four ints: count, blength, s2length and remainder. The values are not needed, so
	// they are only consumed, but every read is checked so a broken connection is noticed
	// right away.
	for i := 0; i < 4; i++ {
		if _, err := r.ReadInt(); err != nil {
			return nil, err
		}
	}

	result := []byte{}
	// match the logic in https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/receiver.c#L314
	for {
		n, buf, err := r.ReceiveToken()
		if err != nil {
			return nil, err
		}

		if n < 0 {
			panic("Received a negative number in DownloadFiles. This would indicate that a checksum matched, which should not be possible.")
		}

		if n == 0 {
			break
		}

		result = append(result, buf...)
	}

	// last but not least, download the digest
	// https://github.com/RsyncProject/rsync/blob/9615a2492bbf96bc145e738ebff55bbb91e0bbee/receiver.c#L325
	digestBuf := make([]byte, r.digestLen)
	if err := r.Read(digestBuf); err != nil {
		return nil, err
	}

	if !slices.Equal(digestBuf, file.Digest) {
		return nil, fmt.Errorf("Digest of file %s does not match. Expected: %v, got: %v", file.Name, file.Digest, digestBuf)
	}

	return result, nil
}
