// Copyright 2014, 2015 by Sascha L. Teichmann
// Use of this source code is governed by the MIT license
// that can be found in the LICENSE file.

package main

import (
	"database/sql"
	"log"
	"sync"

	_ "github.com/mattn/go-sqlite3"

	"bitbucket.org/s_l_teichmann/mtsatellite/common"
)

// globalLock is a single package-level lock shared by every SQLiteBackend
// and session in the process. Read-only operations (Fetch, AllKeys and the
// spatial queries) take the read lock. Mutations (Del, Store), transaction
// begin/commit and Shutdown take the write lock. Streaming operations hand
// their read lock to a goroutine, which releases it once streaming ends.
var globalLock sync.RWMutex

// SQL texts prepared once by NewSQLiteBackend. All of them operate on the
// blocks table, addressing rows by its pos column. rangeSQL uses BETWEEN,
// so both bounds are inclusive, and it returns rows in ascending pos order.
const (
	deleteSQL = "DELETE FROM blocks WHERE pos = ?"
	fetchSQL  = "SELECT data FROM blocks WHERE pos = ?"
	existsSQL = "SELECT 1 FROM blocks WHERE pos = ?"
	updateSQL = "UPDATE blocks SET data = ? WHERE pos = ?"
	insertSQL = "INSERT INTO blocks (pos, data) VALUES (?, ?)"
	countSQL  = "SELECT count(*) FROM blocks"
	keysSQL   = "SELECT pos FROM blocks"
	rangeSQL  = "SELECT pos, data FROM blocks WHERE pos BETWEEN ? AND ? ORDER BY pos"
)

// SQLiteBackend serves map blocks from the blocks table of an SQLite
// database opened by NewSQLiteBackend.
//
// db is the database handle on which every *Stmt field is prepared. Each
// *Stmt field is the prepared form of the SQL constant with the matching
// name (existsStmt from existsSQL, and so on).
//
// encoder turns an int64 pos value into the key bytes handed to clients, and
// decoder does the reverse. Both are chosen by interleaved.
//
// interleaved reports whether pos holds interleaved keys rather than plain
// ones. It selects both the key codecs and the spatial query strategy.
//
// coverage indexes the coordinates of all stored blocks. It is only built
// for non-interleaved databases and is nil otherwise.
//
// changeTracker, if non-nil, is notified of every block written by Store.
type SQLiteBackend struct {
	db            *sql.DB
	encoder       common.KeyEncoder
	decoder       common.KeyDecoder
	changeTracker *changeTracker
	interleaved   bool
	coverage      *common.Coverage3D
	existsStmt    *sql.Stmt
	deleteStmt    *sql.Stmt
	fetchStmt     *sql.Stmt
	insertStmt    *sql.Stmt
	updateStmt    *sql.Stmt
	countStmt     *sql.Stmt
	keysStmt      *sql.Stmt
	rangeStmt     *sql.Stmt
}

// SQLiteSession is one client's view of a SQLiteBackend.
//
// backend is the shared backend. tx is the transaction started by
// BeginTransaction, or nil when none is running. While tx is set, the
// backend's prepared statements are bound to it via txStmt.
type SQLiteSession struct {
	backend *SQLiteBackend
	tx      *sql.Tx
}

// NewSession returns a fresh session on sqlb with no transaction running.
// The error result is always nil.
func (sqlb *SQLiteBackend) NewSession() (Session, error) {
	session := &SQLiteSession{backend: sqlb}
	return session, nil
}

// Close rolls back a transaction that is still open in this session and
// returns the rollback error. Without an open transaction it does nothing
// and returns nil.
func (ss *SQLiteSession) Close() error {
	if ss.tx == nil {
		return nil
	}
	pending := ss.tx
	ss.tx = nil
	return pending.Rollback()
}

// NewSQLiteBackend opens the SQLite database at path and prepares every
// statement the backend needs.
//
// If interleaved is true, the pos column is treated as holding interleaved
// keys and the interleaved key codecs are installed. Otherwise the plain
// codecs are used and a coverage index over all stored blocks is built
// before returning. changeTracker may be nil.
//
// On failure every resource acquired so far is released and sqlb is nil.
func NewSQLiteBackend(
	path string,
	changeTracker *changeTracker, interleaved bool) (sqlb *SQLiteBackend, err error) {

	res := SQLiteBackend{interleaved: interleaved, changeTracker: changeTracker}

	if res.db, err = sql.Open("sqlite3", path); err != nil {
		return
	}

	// Prepare all statements in a fixed order and stop at the first failure.
	prepared := []struct {
		stmt  **sql.Stmt
		query string
	}{
		{&res.existsStmt, existsSQL},
		{&res.fetchStmt, fetchSQL},
		{&res.deleteStmt, deleteSQL},
		{&res.insertStmt, insertSQL},
		{&res.updateStmt, updateSQL},
		{&res.countStmt, countSQL},
		{&res.keysStmt, keysSQL},
		{&res.rangeStmt, rangeSQL},
	}
	for _, p := range prepared {
		if *p.stmt, err = res.db.Prepare(p.query); err != nil {
			res.closeAll()
			return
		}
	}

	if interleaved {
		res.encoder = common.EncodeStringToBytesFromInterleaved
		res.decoder = common.DecodeStringFromBytesToInterleaved
	} else {
		res.encoder = common.EncodeStringToBytes
		res.decoder = common.DecodeStringFromBytes
		// Spatial queries on plain keys rely on the coverage index.
		if err = res.buildCoverage(); err != nil {
			// The caller gets no backend to shut down, so release the
			// database and all prepared statements here.
			res.closeAll()
			return
		}
	}

	sqlb = &res
	return
}

// buildCoverage replaces sqlb.coverage with a new index and inserts the
// coordinate of every pos stored in the blocks table, decoded as a plain key.
// It returns the first query, scan or iteration error. On error the index
// may be only partially filled.
func (sqlb *SQLiteBackend) buildCoverage() error {
	log.Println("INFO: Start building coverage index (this may take some time)...")
	sqlb.coverage = common.NewCoverage3D()

	rows, err := sqlb.keysStmt.Query()
	if err != nil {
		return err
	}
	defer rows.Close()

	for rows.Next() {
		var pos int64
		if scanErr := rows.Scan(&pos); scanErr != nil {
			return scanErr
		}
		sqlb.coverage.Insert(common.PlainToCoord(pos))
	}

	iterErr := rows.Err()
	log.Println("INFO: Finished building coverage index.")
	return iterErr
}

// closeStmt closes the statement held in *stmt, if there is one, and resets
// *stmt to nil so a repeated call is a no-op. It returns the error from Close.
func closeStmt(stmt **sql.Stmt) error {
	if *stmt == nil {
		return nil
	}
	victim := *stmt
	*stmt = nil
	return victim.Close()
}

// closeDB closes the database handle held in *db, if there is one, and
// resets *db to nil so a repeated call is a no-op. It returns the error from
// Close.
func closeDB(db **sql.DB) error {
	if *db == nil {
		return nil
	}
	handle := *db
	*db = nil
	return handle.Close()
}

// closeAll closes every prepared statement and then the database handle.
// Every resource is released even if some of the closes fail, and the first
// error encountered is returned. Close errors from statements are no longer
// dropped silently. Closed handles are reset to nil by closeStmt/closeDB, so
// calling closeAll again is harmless.
func (sqlb *SQLiteBackend) closeAll() error {
	var first error
	keep := func(err error) {
		if err != nil && first == nil {
			first = err
		}
	}
	for _, stmt := range []**sql.Stmt{
		&sqlb.deleteStmt,
		&sqlb.fetchStmt,
		&sqlb.insertStmt,
		&sqlb.updateStmt,
		&sqlb.existsStmt,
		&sqlb.countStmt,
		&sqlb.keysStmt,
		&sqlb.rangeStmt,
	} {
		keep(closeStmt(stmt))
	}
	// The database is closed last, after all statements prepared on it.
	keep(closeDB(&sqlb.db))
	return first
}

// Shutdown releases all prepared statements and the database handle while
// holding the global write lock. It therefore waits for every current holder
// of globalLock, including streaming queries whose goroutines keep the read
// lock until they finish.
func (sqlb *SQLiteBackend) Shutdown() (err error) {
	globalLock.Lock()
	defer globalLock.Unlock()
	err = sqlb.closeAll()
	return err
}

// txStmt returns stmt bound to the session's running transaction, or stmt
// itself when no transaction is active.
func (ss *SQLiteSession) txStmt(stmt *sql.Stmt) *sql.Stmt {
	tx := ss.tx
	if tx == nil {
		return stmt
	}
	return tx.Stmt(stmt)
}

// Del removes the block addressed by key. success reports whether such a
// block existed. It is set to true once the existence probe succeeds, even
// if the subsequent DELETE fails with an error. hash is not used.
func (ss *SQLiteSession) Del(hash, key []byte) (success bool, err error) {
	pos, err := ss.backend.decoder(key)
	if err != nil {
		return false, err
	}

	globalLock.Lock()
	defer globalLock.Unlock()

	// Probe first so the caller learns whether anything was deleted.
	var probe int
	err = ss.txStmt(ss.backend.existsStmt).QueryRow(pos).Scan(&probe)
	switch {
	case err == sql.ErrNoRows:
		return false, nil
	case err != nil:
		return false, err
	}

	_, err = ss.txStmt(ss.backend.deleteStmt).Exec(pos)
	return true, err
}

// Fetch returns the data stored for key. A missing block is not an error:
// in that case both data and err are nil. hash is not used.
func (ss *SQLiteSession) Fetch(hash, key []byte) (data []byte, err error) {
	pos, err := ss.backend.decoder(key)
	if err != nil {
		return nil, err
	}

	globalLock.RLock()
	defer globalLock.RUnlock()

	err = ss.txStmt(ss.backend.fetchStmt).QueryRow(pos).Scan(&data)
	if err == sql.ErrNoRows {
		// Scan leaves data untouched when there is no row.
		err = nil
	}
	return data, err
}

// InTransaction reports whether the session currently holds a transaction,
// i.e. one started by BeginTransaction and not yet committed or closed.
func (ss *SQLiteSession) InTransaction() bool {
	running := ss.tx != nil
	return running
}

// Store writes value for key. An existing row is updated; otherwise a new
// row is inserted. exists reports which of the two happened. After a
// successful write, the coverage index (if any) learns about new blocks and
// the change tracker (if any) is told the block changed. hash is not used.
func (ss *SQLiteSession) Store(hash, key, value []byte) (exists bool, err error) {
	pos, err := ss.backend.decoder(key)
	if err != nil {
		return false, err
	}

	globalLock.Lock()
	defer globalLock.Unlock()

	var probe int
	switch scanErr := ss.txStmt(ss.backend.existsStmt).QueryRow(pos).Scan(&probe); scanErr {
	case nil:
		exists = true
	case sql.ErrNoRows:
		// Block is new; exists stays false.
	default:
		return false, scanErr
	}

	if exists {
		_, err = ss.txStmt(ss.backend.updateStmt).Exec(value, pos)
	} else {
		_, err = ss.txStmt(ss.backend.insertStmt).Exec(pos, value)
	}
	if err != nil {
		return exists, err
	}

	// NOTE: Updating the coverage index and change tracker here is
	// technically premature. Inside a transaction, the write only becomes
	// durable at commit time, and the commit may still fail.
	tracker, coverage := ss.backend.changeTracker, ss.backend.coverage
	if tracker == nil && coverage == nil {
		return exists, nil
	}
	c := common.PlainToCoord(pos)
	if coverage != nil && !exists {
		coverage.Insert(c)
	}
	if tracker != nil {
		tracker.BlockChanged(c)
	}
	return exists, nil
}

// BeginTransaction starts a transaction for this session. If one is already
// running, it logs a warning and returns nil without starting another.
func (ss *SQLiteSession) BeginTransaction() error {
	if ss.tx != nil {
		log.Println("WARN: Already running transaction.")
		return nil
	}

	globalLock.Lock()
	defer globalLock.Unlock()

	tx, err := ss.backend.db.Begin()
	ss.tx = tx
	return err
}

// CommitTransaction commits the session's running transaction and returns
// the commit error. The session no longer holds the transaction afterwards,
// even if the commit fails. Without a running transaction it logs a warning
// and returns nil.
func (ss *SQLiteSession) CommitTransaction() error {
	if ss.tx == nil {
		log.Println("WARN: No transaction running.")
		return nil
	}

	globalLock.Lock()
	defer globalLock.Unlock()

	current := ss.tx
	ss.tx = nil
	return current.Commit()
}

// AllKeys streams the encoded keys of all stored blocks.
//
// It returns a channel delivering the keys, the row count taken just before
// the key query, and an error if either query could not be started. The
// global read lock is held until the streaming goroutine ends. That happens
// when all keys are sent, a row cannot be read or encoded, or done is
// closed. The key channel is closed when streaming stops. Iteration errors
// reported by the driver are now logged instead of silently truncating the
// stream. hash is not used.
func (ss *SQLiteSession) AllKeys(
	hash []byte,
	done <-chan struct{}) (<-chan []byte, int, error) {
	globalLock.RLock()

	countStmt := ss.txStmt(ss.backend.countStmt)
	var n int
	if err := countStmt.QueryRow().Scan(&n); err != nil {
		globalLock.RUnlock()
		if err == sql.ErrNoRows {
			err = nil
		}
		return nil, n, err
	}

	keysStmt := ss.txStmt(ss.backend.keysStmt)
	rows, err := keysStmt.Query()
	if err != nil {
		globalLock.RUnlock()
		return nil, n, err
	}

	keys := make(chan []byte)
	go func() {
		defer globalLock.RUnlock()
		defer rows.Close()
		defer close(keys)
		for rows.Next() {
			var key int64
			if err := rows.Scan(&key); err != nil {
				log.Printf("WARN: %s\n", err)
				return
			}
			encoded, err := ss.backend.encoder(key)
			if err != nil {
				log.Printf("Cannot encode key: %d %s\n", key, err)
				return
			}
			select {
			case keys <- encoded:
			case <-done:
				return
			}
		}
		// Next returning false may hide a driver error; report it.
		if err := rows.Err(); err != nil {
			log.Printf("WARN: %s\n", err)
		}
	}()

	return keys, n, nil
}

// SpatialQuery streams all blocks inside the box spanned by the keys first
// and second. It dispatches to the interleaved or plain implementation
// according to the backend's key layout. hash is not used.
func (ss *SQLiteSession) SpatialQuery(
	hash, first, second []byte,
	done <-chan struct{}) (<-chan Block, error) {

	query := ss.plainSpatialQuery
	if ss.backend.interleaved {
		query = ss.interleavedSpatialQuery
	}
	return query(first, second, done)
}

// interleavedSpatialQuery streams all blocks inside the cuboid spanned by
// the plain-encoded keys first and second, after clipping both corners with
// common.ClipCoord.
//
// Rows are read in ascending interleaved-key order. Whenever a row outside
// the cuboid is met, the range query is restarted from
// common.BigMin(zmin, zmax, zcode), presumably the next key inside the box.
// Block keys are sent in plain encoding. The global read lock is held until
// the goroutine ends. The returned channel is closed when the scan finishes,
// fails (errors are only logged), or done is closed. The rows are now closed
// on every path, including cancellation through done.
func (ss *SQLiteSession) interleavedSpatialQuery(
	first, second []byte,
	done <-chan struct{}) (<-chan Block, error) {

	var (
		firstKey  int64
		secondKey int64
		err       error
	)
	if firstKey, err = common.DecodeStringFromBytes(first); err != nil {
		return nil, err
	}
	if secondKey, err = common.DecodeStringFromBytes(second); err != nil {
		return nil, err
	}
	c1 := common.ClipCoord(common.PlainToCoord(firstKey))
	c2 := common.ClipCoord(common.PlainToCoord(secondKey))
	c1, c2 = common.MinCoord(c1, c2), common.MaxCoord(c1, c2)

	blocks := make(chan Block)

	// Acquired here and released by the goroutine once streaming ends.
	globalLock.RLock()

	go func() {
		defer close(blocks)
		defer globalLock.RUnlock()
		zmin, zmax := common.CoordToInterleaved(c1), common.CoordToInterleaved(c2)
		// Should not be necessary.
		zmin, zmax = common.Order64(zmin, zmax)
		cub := common.Cuboid{P1: c1, P2: c2}
		rangeStmt := ss.txStmt(ss.backend.rangeStmt)

		zcode := zmin

		// scanRange runs one range query over [zcode, zmax] and sends every
		// row inside cub. When it meets a row outside cub, it moves zcode to
		// common.BigMin(zmin, zmax, zcode) and returns true so the caller
		// restarts the query from there. It returns false once the range is
		// exhausted, on any error, or when done is closed. The rows are
		// closed on every path, and a failing Close also ends the scan.
		scanRange := func() (restart bool) {
			rows, err := rangeStmt.Query(zcode, zmax)
			if err != nil {
				log.Printf("error: fetching range failed: %s\n", err)
				return false
			}
			defer func() {
				if err := rows.Close(); err != nil {
					log.Printf("error: closing range failed: %s\n", err)
					restart = false
				}
			}()

			for rows.Next() {
				var data []byte
				if err = rows.Scan(&zcode, &data); err != nil {
					log.Printf("error: scanning row failed: %s\n", err)
					return false
				}
				c := common.InterleavedToCoord(zcode)
				if !cub.Contains(c) {
					zcode = common.BigMin(zmin, zmax, zcode)
					return true
				}
				key := common.StringToBytes(common.CoordToPlain(c))
				select {
				case blocks <- Block{Key: key, Data: data}:
				case <-done:
					// Returning runs the deferred rows.Close, so
					// cancellation does not leak the cursor.
					return false
				}
			}

			if err = rows.Err(); err != nil {
				log.Printf("error: iterating range failed: %s\n", err)
			}
			return false
		}

		for scanRange() {
		}
	}()

	return blocks, nil
}

// plainSpatialQuery streams all stored blocks inside the box spanned by the
// plain-encoded keys first and second.
//
// The box is split using the coverage index. Each result r of
// coverage.Query describes X1..X2 and Y1..Y2 at a fixed Z. It is fetched
// with one range query per Y value, going from the highest Y down. Keys are
// sent in plain encoding. NOTE(review): this assumes ss.backend.coverage is
// non-nil; NewSQLiteBackend builds it only for non-interleaved backends.
// The global read lock is held until the goroutine ends. The returned
// channel is closed when everything is sent, a query fails (errors are only
// logged), or done is closed.
func (ss *SQLiteSession) plainSpatialQuery(
	first, second []byte,
	done <-chan struct{}) (<-chan Block, error) {

	firstKey, err := common.DecodeStringFromBytes(first)
	if err != nil {
		return nil, err
	}
	secondKey, err := common.DecodeStringFromBytes(second)
	if err != nil {
		return nil, err
	}
	lo, hi := common.PlainToCoord(firstKey), common.PlainToCoord(secondKey)
	lo, hi = common.MinCoord(lo, hi), common.MaxCoord(lo, hi)

	blocks := make(chan Block)

	// Acquired here and released by the goroutine once streaming ends.
	globalLock.RLock()

	go func() {
		defer globalLock.RUnlock()
		defer close(blocks)
		rangeStmt := ss.txStmt(ss.backend.rangeStmt)

		// forward sends every row of one range query to blocks and always
		// closes the rows. It returns false if the query failed, a row could
		// not be read or encoded, or done was closed; the caller then aborts
		// the whole spatial query.
		forward := func(rows *sql.Rows, err error) bool {
			if err != nil {
				log.Printf("Error in range query: %s\n", err)
				return false
			}
			defer rows.Close()

			for rows.Next() {
				var (
					pos  int64
					data []byte
				)
				if err := rows.Scan(&pos, &data); err != nil {
					log.Printf("Error in scanning row: %s\n", err)
					return false
				}
				encoded, err := common.EncodeStringToBytes(pos)
				if err != nil {
					log.Printf("Key encoding failed: %s\n", err)
					return false
				}
				select {
				case <-done:
					return false
				case blocks <- Block{Key: encoded, Data: data}:
				}
			}
			if err := rows.Err(); err != nil {
				log.Printf("Error in range query: %s\n", err)
				return false
			}
			return true
		}

		for _, r := range ss.backend.coverage.Query(lo, hi) {
			z, x1, x2 := int16(r.Z), int16(r.X1), int16(r.X2)
			// One range query per Y line, from the highest Y down.
			for y := r.Y2; y >= r.Y1; y-- {
				from := common.CoordToPlain(common.Coord{X: x1, Y: y, Z: z})
				to := common.CoordToPlain(common.Coord{X: x2, Y: y, Z: z})
				if !forward(rangeStmt.Query(from, to)) {
					return
				}
			}
		}
	}()

	return blocks, nil
}