// Copyright 2014, 2015 by Sascha L. Teichmann
// Use of this source code is governed by the MIT license
// that can be found in the LICENSE file.

package main

import (
	"log"
	"sync"

	"bitbucket.org/s_l_teichmann/mtsatellite/common"

	leveldb "github.com/jmhodges/levigo"
)

// LevelDBBackend stores blocks in a LevelDB database accessed through levigo.
// It is shared by all sessions created with NewSession.
type LevelDBBackend struct {
	// cache is the optional LRU block cache handed to LevelDB; nil if disabled.
	cache       *leveldb.Cache
	db          *leveldb.DB
	// interleaved is true if keys are stored in interleaved form rather than
	// in the plain form used by clients.
	interleaved bool
	// coverage indexes the existing blocks; it is only built for
	// non-interleaved databases and is used by plainSpatialQuery.
	coverage    *common.Coverage3D
	// encoder transcodes a storage key into the plain key sent to clients.
	encoder     common.KeyTranscoder
	// decoder transcodes a plain client key into the storage key.
	decoder     common.KeyTranscoder

	// changeTracker, if non-nil, is notified by Store about written blocks.
	changeTracker *changeTracker
	// mutex serializes writers (doWrite) against readers (doRead, AllKeys,
	// the spatial queries).
	mutex         sync.RWMutex
}

// LevelDBSession is a single client session on a LevelDBBackend.
// NOTE: field order matters; NewSession builds this struct positionally.
type LevelDBSession struct {
	// backend is the shared backend this session operates on.
	backend *LevelDBBackend
	// tx is the write batch of the running transaction, or nil if no
	// transaction is running.
	tx      *leveldb.WriteBatch
}

// NewLeveDBBackend opens the LevelDB database at path, creating it if it
// does not exist yet, and returns a backend for it.
//
// changeTracker, if non-nil, is notified by Store about every written block.
// interleaved selects the storage key layout: if true, keys are stored in
// interleaved form and transcoded from/to the plain keys used by clients;
// if false, plain keys are stored as-is and a coverage index is built by
// scanning all keys before the backend is returned (this may take a while
// for large databases). cacheSize is the size of the LRU block cache in MiB;
// a value <= 0 disables the cache.
//
// On error every resource acquired so far is released and ldb is nil.
func NewLeveDBBackend(
	path string,
	changeTracker *changeTracker,
	interleaved bool,
	cacheSize int) (ldb *LevelDBBackend, err error) {

	opts := leveldb.NewOptions()
	// The options object wraps native memory that must be freed explicitly.
	// LevelDB copies the options it needs on Open, so releasing it when this
	// function returns is safe.
	defer opts.Close()

	var cache *leveldb.Cache
	if cacheSize > 0 {
		cache = leveldb.NewLRUCache(cacheSize * 1024 * 1024)
		opts.SetCache(cache)
	}

	opts.SetCreateIfMissing(true)

	var db *leveldb.DB
	if db, err = leveldb.Open(path, opts); err != nil {
		if cache != nil {
			cache.Close()
		}
		return
	}

	var (
		encoder common.KeyTranscoder
		decoder common.KeyTranscoder
	)
	if interleaved {
		encoder = common.TranscodeInterleavedToPlain
		decoder = common.TranscodePlainToInterleaved
	} else {
		encoder = common.IdentityTranscoder
		decoder = common.IdentityTranscoder
	}

	ldb = &LevelDBBackend{
		cache:         cache,
		db:            db,
		interleaved:   interleaved,
		encoder:       encoder,
		decoder:       decoder,
		changeTracker: changeTracker}

	// Plain keys are not spatially ordered, so spatial queries on such
	// databases need the coverage index.
	if !interleaved {
		if err = ldb.buildCoverage(); err != nil {
			ldb.Shutdown()
			ldb = nil
			return
		}
	}
	return
}

// buildCoverage scans every key of the database and builds the in-memory
// coverage index used by plainSpatialQuery. It fails if any key cannot be
// decoded as a block coordinate or if the iteration reports an error; in
// that case ldb.coverage is left untouched.
func (ldb *LevelDBBackend) buildCoverage() error {
	log.Println("INFO: Start building coverage index (this may take some time)...")

	coverage := common.NewCoverage3D()

	ro := leveldb.NewReadOptions()
	defer ro.Close()
	// A full scan should not evict the useful contents of the block cache.
	ro.SetFillCache(false)

	it := ldb.db.NewIterator(ro)
	// The iterator holds native resources; release it on every return path.
	// Deferred after ro.Close, so it is closed before the read options.
	defer it.Close()

	for it.SeekToFirst(); it.Valid(); it.Next() {
		c, err := common.DecodeStringBytesToCoord(it.Key())
		if err != nil {
			return err
		}
		coverage.Insert(c)
	}
	if err := it.GetError(); err != nil {
		return err
	}
	ldb.coverage = coverage
	log.Println("INFO: Finished building coverage index.")
	return nil
}

// doRead runs f with the database while holding the backend's read lock.
// The lock is released via defer so that a panic in f cannot leave it held
// and deadlock every later writer.
func (ldb *LevelDBBackend) doRead(f func(db *leveldb.DB)) {
	ldb.mutex.RLock()
	defer ldb.mutex.RUnlock()
	f(ldb.db)
}

// doWrite runs f with the database while holding the backend's exclusive
// write lock. The lock is released via defer so that a panic in f cannot
// leave it held and deadlock every later reader and writer.
func (ldb *LevelDBBackend) doWrite(f func(db *leveldb.DB)) {
	ldb.mutex.Lock()
	defer ldb.mutex.Unlock()
	f(ldb.db)
}

// NewSession returns a fresh session on this backend with no transaction
// running. It never fails.
func (ldb *LevelDBBackend) NewSession() (Session, error) {
	session := &LevelDBSession{backend: ldb}
	return session, nil
}

// Close releases the session's resources. A running, uncommitted
// transaction is discarded. Close is safe to call more than once and
// always returns nil.
func (ldbs *LevelDBSession) Close() error {
	if ldbs.tx != nil {
		ldbs.tx.Close()
		// Drop the reference so that a second Close, InTransaction or a later
		// CommitTransaction cannot touch the already closed write batch.
		ldbs.tx = nil
	}
	return nil
}

// Shutdown closes the database and afterwards releases the block cache, if
// one was configured. The cache goes last because the database was opened
// with it. Shutdown always returns nil.
func (ldb *LevelDBBackend) Shutdown() error {
	db, cache := ldb.db, ldb.cache
	db.Close()
	if cache == nil {
		return nil
	}
	cache.Close()
	return nil
}

// Del removes key from the database and reports whether it existed.
// The key is transcoded to the storage layout first; hash is ignored.
//
// The delete is applied immediately under the backend's write lock, even if
// a transaction is running on this session. The coverage index is not
// updated.
func (ldbs *LevelDBSession) Del(hash, key []byte) (success bool, err error) {
	if key, err = ldbs.backend.decoder(key); err != nil {
		return
	}
	ldbs.backend.doWrite(func(db *leveldb.DB) {
		// Deleting a missing key is a no-op; only report success if it existed.
		if success, err = keyExists(db, key); err != nil || !success {
			return
		}
		wo := leveldb.NewWriteOptions()
		defer wo.Close()
		err = db.Delete(wo, key)
	})
	return
}

// Fetch returns the value stored under key, or nil if the key does not
// exist. The key is transcoded to the storage layout first; hash is ignored.
// The read happens under the backend's read lock and does not see writes
// queued in an uncommitted transaction.
func (ldbs *LevelDBSession) Fetch(hash, key []byte) (value []byte, err error) {
	if key, err = ldbs.backend.decoder(key); err != nil {
		return
	}
	ldbs.backend.doRead(func(db *leveldb.DB) {
		ro := leveldb.NewReadOptions()
		defer ro.Close()
		value, err = db.Get(ro, key)
	})
	return
}

// InTransaction reports whether BeginTransaction has started a transaction
// that has not been committed yet.
func (ldbs *LevelDBSession) InTransaction() bool {
	running := ldbs.tx != nil
	return running
}

// keyExists reports whether db currently holds a value for key.
// Get reports a missing key as a nil value, which is what is tested here.
// On a read error, exists is false and the error is returned.
func keyExists(db *leveldb.DB, key []byte) (exists bool, err error) {
	opts := leveldb.NewReadOptions()
	defer opts.Close()
	value, getErr := db.Get(opts, key)
	if getErr != nil {
		return false, getErr
	}
	return value != nil, nil
}

// Store writes value under key; hash is ignored. The key is given in plain
// form and transcoded to the storage layout first. If a transaction is
// running, the write is queued in the session's write batch; otherwise it is
// written to the database immediately.
//
// exists reports whether the key was already present in the database.
// Writes queued in the current, uncommitted transaction are not seen.
//
// If the backend maintains a coverage index or a change tracker, the key must
// decode to a block coordinate. It is validated before anything is written.
// buildCoverage refuses to open a database that contains an undecodable key,
// so such a key must never be persisted.
func (ldbs *LevelDBSession) Store(hash, key, value []byte) (exists bool, err error) {
	origKey := key
	if key, err = ldbs.backend.decoder(key); err != nil {
		return
	}

	backend := ldbs.backend
	needCoord := backend.changeTracker != nil || backend.coverage != nil

	var c common.Coord
	if needCoord {
		if c, err = common.DecodeStringBytesToCoord(origKey); err != nil {
			return
		}
	}

	backend.doWrite(func(db *leveldb.DB) {
		if exists, err = keyExists(db, key); err != nil {
			return
		}
		if ldbs.tx != nil {
			ldbs.tx.Put(key, value)
			return
		}
		wo := leveldb.NewWriteOptions()
		defer wo.Close()
		err = db.Put(wo, key, value)
	})
	if err != nil || !needCoord {
		return
	}

	// NOTE: Inside a transaction this is technically too early. The write is
	// only queued here and is committed (and may still fail) later.
	if backend.coverage != nil && !exists {
		backend.coverage.Insert(c)
	}
	if backend.changeTracker != nil {
		backend.changeTracker.BlockChanged(c)
	}
	return
}

// BeginTransaction starts a transaction: subsequent Store calls on this
// session are queued in a write batch until CommitTransaction is called.
// If a transaction is already running, its pending writes are discarded and
// its write batch is released before the new one is created.
// It always returns nil.
func (ldbs *LevelDBSession) BeginTransaction() error {
	if ldbs.tx != nil {
		log.Println("WARN: Transaction already running. Discarding pending writes.")
		// The batch holds native resources; overwriting the pointer alone
		// would leak them.
		ldbs.tx.Close()
	}
	ldbs.tx = leveldb.NewWriteBatch()
	return nil
}

// CommitTransaction atomically applies the write batch of the running
// transaction with a synchronous write, under the backend's write lock.
// The batch is released afterwards whether or not the write succeeded.
// Without a running transaction it only logs a warning and returns nil.
func (ldbs *LevelDBSession) CommitTransaction() (err error) {
	batch := ldbs.tx
	if batch == nil {
		log.Println("WARN: No transaction running.")
		return nil
	}
	// Detach the batch first: the session is out of the transaction from
	// now on, regardless of the outcome of the write.
	ldbs.tx = nil
	ldbs.backend.doWrite(func(db *leveldb.DB) {
		defer batch.Close()
		opts := leveldb.NewWriteOptions()
		defer opts.Close()
		opts.SetSync(true)
		err = db.Write(opts, batch)
	})
	return err
}

// AllKeys streams every key of the database over the returned channel,
// transcoded back to plain client form with the backend's encoder; hash is
// ignored. It also returns the total number of keys, counted by a first full
// pass over the same iterator before streaming starts.
//
// The backend's read lock covers iterator creation and counting. It is
// released as the first action of the streaming goroutine. The goroutine
// stops early when done is closed, and also when a key cannot be encoded.
// In that case fewer than the returned count are sent. Encoding and iteration
// errors from the goroutine are only logged. The returned channel is closed
// when streaming ends.
//
// NOTE(review): the count and the stream stay in agreement only if the
// iterator presents a stable view across both passes (LevelDB iterators use
// an implicit snapshot) — confirm before relying on it.
func (ldbs *LevelDBSession) AllKeys(
	hash []byte,
	done <-chan struct{}) (<-chan []byte, int, error) {

	ldbs.backend.mutex.RLock()

	ro := leveldb.NewReadOptions()
	// A full scan should not evict the useful contents of the block cache.
	ro.SetFillCache(false)

	it := ldbs.backend.db.NewIterator(ro)
	it.SeekToFirst()
	// First pass: count the keys so the caller knows the total up front.
	var n int
	for ; it.Valid(); it.Next() {
		n++
	}

	if err := it.GetError(); err != nil {
		it.Close()
		ro.Close()
		ldbs.backend.mutex.RUnlock()
		return nil, n, err
	}

	keys := make(chan []byte)

	go func() {
		// The read lock taken above is released here, so it is held only
		// until the streaming goroutine starts.
		ldbs.backend.mutex.RUnlock()
		defer ro.Close()
		defer close(keys)
		defer it.Close()
		// Second pass over the same iterator: stream the keys.
		it.SeekToFirst()
		encoder := ldbs.backend.encoder
		for ; it.Valid(); it.Next() {
			if key, err := encoder(it.Key()); err == nil {
				select {
				case keys <- key:
				case <-done:
					return
				}
			} else {
				log.Printf("WARN: %s\n", err)
				return
			}
		}
		if err := it.GetError(); err != nil {
			log.Printf("WARN: %s\n", err)
		}
	}()

	return keys, n, nil
}

// SpatialQuery streams all blocks inside the cuboid spanned by the plain
// keys first and second; hash is ignored. It dispatches on the storage
// layout: interleaved databases are scanned directly, plain databases are
// answered with the help of the coverage index.
func (ldbs *LevelDBSession) SpatialQuery(
	hash, first, second []byte,
	done <-chan struct{}) (<-chan Block, error) {

	query := ldbs.plainSpatialQuery
	if ldbs.backend.interleaved {
		query = ldbs.interleavedSpatialQuery
	}
	return query(first, second, done)
}

// plainSpatialQuery answers a spatial query on a database that stores plain
// keys. first and second are plain keys naming two corners of the query
// cuboid; they are decoded and normalized into min/max corners. The coverage
// index yields ranges (Z, Y1..Y2, X1..X2) of existing blocks inside the
// cuboid, and every candidate key in those ranges is fetched with an
// individual Get. Blocks that exist are sent on the returned channel, which
// is closed when the query finishes, when done is closed, or when a read
// fails. Read failures are only logged. The backend's read lock is held for
// the whole lifetime of the query goroutine.
//
// Only key decoding errors are returned directly.
func (ldbs *LevelDBSession) plainSpatialQuery(
	first, second []byte,
	done <-chan struct{}) (<-chan Block, error) {

	var (
		firstKey  int64
		secondKey int64
		err       error
	)
	if firstKey, err = common.DecodeStringFromBytes(first); err != nil {
		return nil, err
	}
	if secondKey, err = common.DecodeStringFromBytes(second); err != nil {
		return nil, err
	}
	c1 := common.PlainToCoord(firstKey)
	c2 := common.PlainToCoord(secondKey)
	c1, c2 = common.MinCoord(c1, c2), common.MaxCoord(c1, c2)

	blocks := make(chan Block)

	go func() {
		defer close(blocks)
		ldbs.backend.mutex.RLock()
		defer ldbs.backend.mutex.RUnlock()

		ro := leveldb.NewReadOptions()
		defer ro.Close()

		// a and b are the first and last coordinate of one X row.
		var a, b common.Coord

		for _, r := range ldbs.backend.coverage.Query(c1, c2) {
			a.Z, b.Z = int16(r.Z), int16(r.Z)
			a.X, b.X = int16(r.X1), int16(r.X2)
			for a.Y = r.Y2; a.Y >= r.Y1; a.Y-- {
				b.Y = a.Y
				// The keys in the database are stored and ordered as strings
				// ("1", "10", ..., "19", "2", "20", "21"), so an iterator does
				// not visit them in numerical order.
				// Each block is therefore fetched with a Get instead.
				// NOTE(review): stepping f by one assumes CoordToPlain maps
				// consecutive X values (same Y and Z) to consecutive integers —
				// confirm in common.
				for f, t := common.CoordToPlain(a), common.CoordToPlain(b); f <= t; f++ {
					key := common.StringToBytes(f)
					value, err := ldbs.backend.db.Get(ro, key)
					if err != nil {
						log.Printf("get failed: %s\n", err)
						return
					}
					// A nil value means no block is stored under this key.
					if value != nil {
						select {
						case blocks <- Block{Key: key, Data: value}:
						case <-done:
							return
						}
					}
				}

			}
		}
	}()
	return blocks, nil
}

// interleavedSpatialQuery answers a spatial query on a database whose keys
// are interleaved codes encoded with common.ToBigEndian. first and second are
// plain keys naming two corners of the query cuboid. The corners are clipped
// with common.ClipCoord and normalized into min/max corners.
//
// The database is scanned from the interleaved code of the min corner up to
// that of the max corner. Keys whose coordinate lies outside the cuboid are
// skipped by seeking to common.BigMin(zmin, zmax, zcode). Presumably this is
// the next code inside the cuboid (the BIGMIN step of a Z-order range search)
// — confirm in common.
//
// Matching blocks are sent with their keys re-encoded in plain string form.
// The returned channel is closed when the scan ends, when done is closed, or
// on an error; such errors are only logged. The backend's read lock is held
// for the whole lifetime of the scan goroutine.
//
// Only key decoding errors are returned directly.
func (ldbs *LevelDBSession) interleavedSpatialQuery(
	first, second []byte,
	done <-chan struct{}) (<-chan Block, error) {

	var (
		firstKey  int64
		secondKey int64
		err       error
	)
	if firstKey, err = common.DecodeStringFromBytes(first); err != nil {
		return nil, err
	}
	if secondKey, err = common.DecodeStringFromBytes(second); err != nil {
		return nil, err
	}
	c1 := common.ClipCoord(common.PlainToCoord(firstKey))
	c2 := common.ClipCoord(common.PlainToCoord(secondKey))
	c1, c2 = common.MinCoord(c1, c2), common.MaxCoord(c1, c2)

	blocks := make(chan Block)

	go func() {
		defer close(blocks)
		ldbs.backend.mutex.RLock()
		defer ldbs.backend.mutex.RUnlock()

		ro := leveldb.NewReadOptions()
		defer ro.Close()
		// A range scan should not evict the useful contents of the block cache.
		ro.SetFillCache(false)

		it := ldbs.backend.db.NewIterator(ro)
		defer it.Close()

		zmin, zmax := common.CoordToInterleaved(c1), common.CoordToInterleaved(c2)
		// Should not be necessary.
		zmin, zmax = common.Order64(zmin, zmax)
		var (
			cub        = common.Cuboid{P1: c1, P2: c2}
			err        error
			encodedKey []byte
		)

		//log.Printf("seeking to: %d\n", zmin)
		it.Seek(common.ToBigEndian(zmin))
		for it.Valid() {
			zcode := common.FromBigEndian(it.Key())

			// Past the code of the max corner, nothing else can match.
			if zcode > zmax {
				break
			}

			if c := common.InterleavedToCoord(zcode); cub.Contains(c) {
				if encodedKey, err = common.EncodeStringToBytes(common.CoordToPlain(c)); err != nil {
					log.Printf("error encoding key: %s\n", err)
					return
				}
				select {
				case blocks <- Block{Key: encodedKey, Data: it.Value()}:
				case <-done:
					return
				}
				it.Next()
			} else {
				// The code lies in [zmin, zmax] but its coordinate is outside
				// the cuboid: jump ahead instead of stepping key by key.
				next := common.BigMin(zmin, zmax, zcode)
				//log.Printf("seeking to: %d\n", next)
				it.Seek(common.ToBigEndian(next))
				//log.Printf("seeking done: %d\n", next)
			}
		}
		//log.Println("iterating done")
		if err = it.GetError(); err != nil {
			log.Printf("error while iterating: %s\n", err)
			return
		}
	}()
	return blocks, nil
}