// Copyright 2014 by Sascha L. Teichmann
// Use of this source code is governed by the MIT license
// that can be found in the LICENSE file.

package main

import (
	"bufio"
	"fmt"
	"log"
	"net"
)

// Pre-encoded RESP reply fragments shared by every Connection.
var (
	// Simple-string status replies.
	redisOk     = []byte("+OK\r\n")
	redisQueued = []byte("+QUEUED\r\n")

	// Generic error reply; no error text is sent to the client.
	redisError = []byte("-ERR\r\n")

	// Null bulk string, written by writeBulkString for nil data.
	redisNoSuchBlock = []byte("$-1\r\n")

	// Array header announcing zero elements.
	redisEmptyArray = []byte("*0\r\n")

	// Line terminator that closes a bulk-string payload.
	redisCrnl = []byte("\r\n")
)

// Connection serves a single client socket: Run parses Redis-protocol
// requests from conn, and the command methods (Hget, Hset, Hkeys, ...)
// execute them against session and write the replies back to conn.
type Connection struct {
	// conn is the client socket; requests are read from it and every
	// reply is written to it directly (unbuffered).
	conn      net.Conn
	// session is the backend the commands are executed against.
	session   Session
	// boolArray accumulates the results of Hset calls made while the
	// session is in a transaction; Exec replies with them and resets it.
	boolArray []bool
}

// NewConnection wraps an accepted client socket and the session that
// will execute its commands. The returned Connection starts with an
// empty (non-nil) list of queued transaction results.
func NewConnection(conn net.Conn, session Session) *Connection {
	c := new(Connection)
	c.conn = conn
	c.session = session
	c.boolArray = []bool{}
	return c
}

// Run serves the connection until the Redis parser's Parse call returns
// (presumably on EOF or a read error — confirm in RedisParser), then
// closes the session followed by the network connection.
func (c *Connection) Run() {
	// The closure reads c.session / c.conn at exit time, not at defer time.
	cleanup := func() {
		c.session.Close()
		c.conn.Close()
	}
	defer cleanup()

	const readBufferSize = 8 * 1024

	executor := NewRedisCommandExecutor(c)
	reader := bufio.NewReaderSize(c.conn, readBufferSize)
	parser := NewRedisParser(reader, executor)
	parser.Parse()

	log.Println("client disconnected")
}

// logError writes err to the standard logger, prefixed with "ERROR: ".
func logError(err error) {
	const format = "ERROR: %s"
	log.Printf(format, err)
}

// Hget handles an HGET request: it fetches key from hash via the session
// and replies with the value as a bulk string (nil data is sent as the
// null bulk string). A Fetch error is logged and answered with -ERR.
// It returns false only if writing the reply to the client failed.
func (c *Connection) Hget(hash, key []byte) bool {
	data, err := c.session.Fetch(hash, key)
	if err != nil {
		return c.writeError(err)
	}
	return c.writeBlock(data)
}

// Hset handles an HSET request by storing data under key in hash.
//
// Inside a transaction, the boolean reported by Store is queued for Exec
// and the client gets +QUEUED. Outside a transaction it is replied
// immediately as the integer 1 or 0. A Store error is logged and
// answered with -ERR, and nothing is queued.
// NOTE(review): Redis HSET replies 1 when the field is *new*; confirm that
// Store's bool (named "exists" here) has the polarity clients expect.
// It returns false only if writing the reply to the client failed.
func (c *Connection) Hset(hash, key, data []byte) bool {
	exists, err := c.session.Store(hash, key, data)
	switch {
	case err != nil:
		return c.writeError(err)
	case c.session.InTransaction():
		c.boolArray = append(c.boolArray, exists)
		return c.writeQueued()
	default:
		return c.writeBool(exists)
	}
}

// Multi handles MULTI by starting a session transaction and replying +OK.
// A nested MULTI is tolerated: it only logs a warning and still replies
// +OK. If BeginTransaction fails, the error is logged and -ERR is sent.
// It returns false only if writing the reply to the client failed.
func (c *Connection) Multi() bool {
	if c.session.InTransaction() {
		log.Println("WARN: Already running transaction.")
		return c.writeOk()
	}
	if err := c.session.BeginTransaction(); err != nil {
		return c.writeError(err)
	}
	return c.writeOk()
}

// Exec handles EXEC. Outside a transaction it replies with an empty array.
// Otherwise it commits the session transaction and replies with an array
// holding the queued Hset results, in order. The queued results are
// discarded before the commit, so they are dropped even when the commit
// fails (the client then receives -ERR).
// It returns false only if writing the reply to the client failed.
func (c *Connection) Exec() bool {
	if c.session.InTransaction() {
		results := c.boolArray
		c.boolArray = []bool{}
		if err := c.session.CommitTransaction(); err != nil {
			return c.writeError(err)
		}
		return c.writeBoolArray(results)
	}
	return c.writeEmptyArray()
}

// Hkeys handles HKEYS: it asks the session for all keys of hash and replies
// with a RESP array of bulk strings. When AllKeys reports zero keys, an
// empty array is sent. The done channel is closed when Hkeys returns,
// presumably to let the key producer stop early (confirm in AllKeys).
// NOTE(review): the array header uses the count reported by AllKeys, not the
// number of keys actually received; a mismatch would desynchronise the client.
// It returns false only if writing to the client failed.
func (c *Connection) Hkeys(hash []byte) bool {
	done := make(chan struct{})
	defer close(done)

	keys, count, err := c.session.AllKeys(hash, done)
	if err != nil {
		return c.writeError(err)
	}
	if count == 0 {
		return c.writeEmptyArray()
	}

	// Fprintf formats into a buffer and issues exactly one Write.
	if _, err = fmt.Fprintf(c.conn, "*%d\r\n", count); err != nil {
		logError(err)
		return false
	}

	for key := range keys {
		if err = c.writeBulkString(key); err != nil {
			logError(err)
			return false
		}
	}
	return true
}

// HSpatial answers a spatial query on hash between first and second.
// Every matching block is streamed as two bulk strings (its key, then its
// data), and the stream is terminated by a null bulk string ("$-1"). The
// reply is not a RESP array: clients read until they see the null.
// The done channel is closed when HSpatial returns, presumably to stop the
// query producer early (confirm in SpatialQuery).
// It returns false only if writing to the client failed.
func (c *Connection) HSpatial(hash, first, second []byte) bool {
	done := make(chan struct{})
	defer close(done)

	blocks, err := c.session.SpatialQuery(hash, first, second, done)
	if err != nil {
		return c.writeError(err)
	}

	// emit writes one bulk string and logs a write failure.
	emit := func(payload []byte) bool {
		if werr := c.writeBulkString(payload); werr != nil {
			logError(werr)
			return false
		}
		return true
	}

	for block := range blocks {
		// || short-circuits: Data is not sent if sending Key failed.
		if !emit(block.Key) || !emit(block.Data) {
			return false
		}
	}
	return emit(nil)
}

// writeError logs err and sends the generic "-ERR" reply; the error text
// itself is not sent to the client. It reports whether the reply was
// written successfully; a write failure is logged as well.
func (c *Connection) writeError(err error) bool {
	logError(err)
	_, writeErr := c.conn.Write(redisError)
	if writeErr == nil {
		return true
	}
	logError(writeErr)
	return false
}

// writeEmptyArray sends the zero-length array reply "*0". It reports
// whether the write succeeded; a failure is logged.
func (c *Connection) writeEmptyArray() bool {
	_, err := c.conn.Write(redisEmptyArray)
	if err != nil {
		logError(err)
	}
	return err == nil
}

// asInt maps true to 1 and false to 0, for RESP integer replies.
func asInt(b bool) int {
	n := 0
	if b {
		n = 1
	}
	return n
}

// writeBool sends b as a RESP integer reply (":1" or ":0"). It reports
// whether the write succeeded; a failure is logged.
func (c *Connection) writeBool(b bool) bool {
	// Fprintf formats into a buffer and issues exactly one Write.
	_, err := fmt.Fprintf(c.conn, ":%d\r\n", asInt(b))
	if err != nil {
		logError(err)
	}
	return err == nil
}

// writeBoolArray sends arr as a RESP array of integer replies (1/0). It
// stops at the first failed write and returns false; header write
// failures are logged here, element failures inside writeBool.
func (c *Connection) writeBoolArray(arr []bool) bool {
	if _, err := fmt.Fprintf(c.conn, "*%d\r\n", len(arr)); err != nil {
		logError(err)
		return false
	}
	for i := range arr {
		if ok := c.writeBool(arr[i]); !ok {
			return false
		}
	}
	return true
}

// writeOk sends the "+OK" status reply. It reports whether the write
// succeeded; a failure is logged.
func (c *Connection) writeOk() bool {
	_, err := c.conn.Write(redisOk)
	if err == nil {
		return true
	}
	logError(err)
	return false
}

// writeQueued sends the "+QUEUED" reply used for commands issued inside a
// transaction. It reports whether the write succeeded; a failure is logged.
func (c *Connection) writeQueued() bool {
	_, err := c.conn.Write(redisQueued)
	if err != nil {
		logError(err)
	}
	return err == nil
}

// writeBlock sends data as a bulk string (nil becomes the null bulk
// string). It reports whether the write succeeded; a failure is logged.
func (c *Connection) writeBlock(data []byte) bool {
	err := c.writeBulkString(data)
	if err != nil {
		logError(err)
	}
	return err == nil
}

// writeBulkString writes data to the client as a RESP bulk string,
// "$<len>\r\n<data>\r\n". A nil slice is encoded as the null bulk string
// "$-1\r\n", while an empty non-nil slice yields "$0\r\n\r\n".
//
// The frame is assembled in memory and sent with a single Write. The
// previous version used three Writes (header, payload, terminator), which
// cost three syscalls and, since Go enables TCP_NODELAY by default, usually
// three separate small TCP segments per reply. The write error, if any,
// is returned unlogged; callers decide how to report it.
func (c *Connection) writeBulkString(data []byte) error {
	if data == nil {
		_, err := c.conn.Write(redisNoSuchBlock)
		return err
	}

	header := fmt.Sprintf("$%d\r\n", len(data))
	frame := make([]byte, 0, len(header)+len(data)+len(redisCrnl))
	frame = append(frame, header...)
	frame = append(frame, data...)
	frame = append(frame, redisCrnl...)

	_, err := c.conn.Write(frame)
	return err
}