mtsatellite/cmd/mtredisalize/sqlite.go
2015-07-24 10:01:39 +02:00

495 lines
10 KiB
Go

// Copyright 2014 by Sascha L. Teichmann
// Use of this source code is governed by the MIT license
// that can be found in the LICENSE file.
package main
import (
"database/sql"
"log"
"sync"
_ "github.com/mattn/go-sqlite3"
"bitbucket.org/s_l_teichmann/mtsatellite/common"
)
// globalLock serialises access to the SQLite database across all sessions.
// Fetch, AllKeys and the spatial queries take the read lock; Store,
// BeginTransaction, CommitTransaction and Shutdown take the write lock.
var globalLock sync.RWMutex

// SQL statements prepared once by NewSqliteBackend. Block positions are
// integer values in the pos column; block payloads are stored in data.
const (
	// Single-block access.
	existsSQL = "SELECT 1 FROM blocks WHERE pos = ?"
	fetchSQL  = "SELECT data FROM blocks WHERE pos = ?"
	insertSQL = "INSERT INTO blocks (pos, data) VALUES (?, ?)"
	updateSQL = "UPDATE blocks SET data = ? WHERE pos = ?"

	// Whole-table and range access.
	countSQL = "SELECT count(*) FROM blocks"
	keysSQL  = "SELECT pos FROM blocks"
	rangeSQL = "SELECT pos, data FROM blocks WHERE pos BETWEEN ? AND ? ORDER BY pos"
)
type (
	// SqliteBackend stores map blocks in an SQLite database. It owns the
	// database handle and the prepared statements shared by all sessions.
	SqliteBackend struct {
		db *sql.DB

		// Convert between the external key encoding and the integer pos
		// column; selected by NewSqliteBackend according to interleaved.
		encoder common.KeyEncoder
		decoder common.KeyDecoder

		// changeTracker, if non-nil, is notified by Store of changed blocks.
		changeTracker *ChangeTracker
		// interleaved selects the interleaved key layout instead of the
		// plain one.
		interleaved bool
		// coverage indexes the positions present in the database. It is
		// built only for the plain layout and stays nil otherwise.
		coverage *common.Coverage3D

		// Prepared statements, one per SQL constant.
		existsStmt, fetchStmt, insertStmt, updateStmt *sql.Stmt
		countStmt, keysStmt, rangeStmt                *sql.Stmt
	}

	// SqliteSession runs operations against a SqliteBackend. tx is the
	// session's running transaction, or nil when none is running.
	SqliteSession struct {
		backend *SqliteBackend
		tx      *sql.Tx
	}
)
// NewSession creates a session bound to this backend. The session starts
// without a transaction; the returned error is always nil.
func (sqlb *SqliteBackend) NewSession() (Session, error) {
	session := &SqliteSession{backend: sqlb}
	return session, nil
}
// Close rolls back the session's running transaction, if any, and detaches
// it from the session. Closing a session without a transaction is a no-op.
func (ss *SqliteSession) Close() error {
	if ss.tx == nil {
		return nil
	}
	pending := ss.tx
	ss.tx = nil
	return pending.Rollback()
}
// NewSqliteBackend opens the SQLite database at path, prepares every SQL
// statement used by the sessions and selects the key encoder/decoder for the
// key layout chosen by interleaved. For the plain (non-interleaved) layout it
// also builds the coverage index over all existing keys. changeTracker is
// stored as-is; Store only notifies it when it is non-nil.
//
// On any failure all resources acquired so far are released, sqlb is nil
// and err describes the failure.
func NewSqliteBackend(
	path string,
	changeTracker *ChangeTracker, interleaved bool) (sqlb *SqliteBackend, err error) {

	res := SqliteBackend{interleaved: interleaved, changeTracker: changeTracker}

	if res.db, err = sql.Open("sqlite3", path); err != nil {
		return
	}

	// Statements are prepared in this fixed order; the first failure aborts.
	prepared := []struct {
		dst   **sql.Stmt
		query string
	}{
		{&res.existsStmt, existsSQL},
		{&res.fetchStmt, fetchSQL},
		{&res.insertStmt, insertSQL},
		{&res.updateStmt, updateSQL},
		{&res.countStmt, countSQL},
		{&res.keysStmt, keysSQL},
		{&res.rangeStmt, rangeSQL},
	}
	for _, p := range prepared {
		if *p.dst, err = res.db.Prepare(p.query); err != nil {
			// Best-effort cleanup; the Prepare error is what gets reported.
			res.closeAll()
			return
		}
	}

	if interleaved {
		res.encoder = common.EncodeStringToBytesFromInterleaved
		res.decoder = common.DecodeStringFromBytesToInterleaved
	} else {
		res.encoder = common.EncodeStringToBytes
		res.decoder = common.DecodeStringFromBytes
		// Only the plain layout uses the coverage index.
		if err = res.buildCoverage(); err != nil {
			// The backend is not handed to the caller, so release the
			// database handle and the prepared statements here.
			res.closeAll()
			return
		}
	}

	sqlb = &res
	return
}
// buildCoverage replaces sb.coverage with a fresh index and fills it with
// the coordinates of every key in the blocks table. Keys are interpreted as
// plain positions via common.PlainToCoord. It returns the first query, scan
// or iteration error.
func (sb *SqliteBackend) buildCoverage() (err error) {
	log.Println("INFO: Start building coverage index (this may take some time)...")
	sb.coverage = common.NewCoverage3D()

	rows, err := sb.keysStmt.Query()
	if err != nil {
		return err
	}
	defer rows.Close()

	for rows.Next() {
		var pos int64
		if err := rows.Scan(&pos); err != nil {
			return err
		}
		sb.coverage.Insert(common.PlainToCoord(pos))
	}

	err = rows.Err()
	log.Println("INFO: Finished building coverage index.")
	return err
}
// closeStmt closes *stmt if it is non-nil. The pointer is reset to nil
// before closing, so calling it again is a harmless no-op. It returns the
// error from Close, or nil when there was nothing to close.
func closeStmt(stmt **sql.Stmt) error {
	if *stmt == nil {
		return nil
	}
	open := *stmt
	*stmt = nil
	return open.Close()
}
// closeDB closes *db if it is non-nil. The pointer is reset to nil before
// closing, so calling it again is a harmless no-op. It returns the error
// from Close, or nil when there was nothing to close.
func closeDB(db **sql.DB) error {
	if *db == nil {
		return nil
	}
	handle := *db
	*db = nil
	return handle.Close()
}
// closeAll closes every prepared statement and then the database handle,
// resetting each field to nil so a repeated call is harmless. It keeps
// closing after a failure and returns the first error encountered (nil if
// everything closed cleanly). Previously only the database error was
// reported and statement close errors were dropped.
func (sqlb *SqliteBackend) closeAll() error {
	var firstErr error
	record := func(err error) {
		if firstErr == nil {
			firstErr = err
		}
	}
	for _, stmt := range []**sql.Stmt{
		&sqlb.fetchStmt,
		&sqlb.insertStmt,
		&sqlb.updateStmt,
		&sqlb.existsStmt,
		&sqlb.countStmt,
		&sqlb.keysStmt,
		&sqlb.rangeStmt,
	} {
		record(closeStmt(stmt))
	}
	// The handle is closed last, after all statements derived from it.
	record(closeDB(&sqlb.db))
	return firstErr
}
// Shutdown releases all prepared statements and the database handle. It
// holds the global write lock, so it does not overlap with any operation
// that holds globalLock.
func (sqlb *SqliteBackend) Shutdown() error {
	globalLock.Lock()
	defer globalLock.Unlock()
	err := sqlb.closeAll()
	return err
}
// txStmt returns stmt bound to the session's running transaction via
// Tx.Stmt, or stmt itself when no transaction is running.
func (ss *SqliteSession) txStmt(stmt *sql.Stmt) *sql.Stmt {
	if tx := ss.tx; tx != nil {
		return tx.Stmt(stmt)
	}
	return stmt
}
// Fetch returns the stored data for the encoded key, using the session's
// transaction if one is running. A missing block is not an error: Fetch
// returns nil data and a nil error. hash is not used by this backend.
func (ss *SqliteSession) Fetch(hash, key []byte) (data []byte, err error) {
	pos, err := ss.backend.decoder(key)
	if err != nil {
		return nil, err
	}

	globalLock.RLock()
	defer globalLock.RUnlock()

	err = ss.txStmt(ss.backend.fetchStmt).QueryRow(pos).Scan(&data)
	if err == sql.ErrNoRows {
		return nil, nil
	}
	return data, err
}
// InTransaction reports whether the session currently holds a transaction
// started by BeginTransaction.
func (ss *SqliteSession) InTransaction() bool {
	running := ss.tx != nil
	return running
}
// Store writes value as the block data for the encoded key. If a row for
// the decoded position already exists it is updated, otherwise a new row is
// inserted; exists reports which case applied. The statements run inside the
// session's transaction when one is running.
//
// Only after a successful write is the position added to the coverage index
// (when the backend has one) and reported to the change tracker (when one is
// configured). A failed write is returned without touching either. hash is
// not used by this backend.
func (ss *SqliteSession) Store(hash, key, value []byte) (exists bool, err error) {
	var pos int64
	if pos, err = ss.backend.decoder(key); err != nil {
		return
	}

	globalLock.Lock()
	defer globalLock.Unlock()

	var x int
	switch err = ss.txStmt(ss.backend.existsStmt).QueryRow(pos).Scan(&x); err {
	case nil:
		exists = true
	case sql.ErrNoRows:
		// No row for this position yet: insert instead of update.
		exists, err = false, nil
	default:
		return
	}

	if exists {
		_, err = ss.txStmt(ss.backend.updateStmt).Exec(value, pos)
	} else {
		_, err = ss.txStmt(ss.backend.insertStmt).Exec(pos, value)
	}
	if err != nil {
		// The block was not written, so it must not be announced as changed.
		return
	}

	// NOTE: Inside a transaction this bookkeeping is slightly premature,
	// because the transaction is committed (and may still fail) later.
	if ss.backend.coverage != nil {
		ss.backend.coverage.Insert(common.PlainToCoord(pos))
	}
	if ss.backend.changeTracker != nil {
		ss.backend.changeTracker.BlockChanged(key)
	}
	return
}
// BeginTransaction starts a transaction for this session while holding the
// global write lock. If a transaction is already running, it logs a warning
// and returns nil without starting another. The result of db.Begin is
// stored in the session as-is, so a failed Begin leaves the session without
// a transaction.
func (ss *SqliteSession) BeginTransaction() (err error) {
	if ss.InTransaction() {
		log.Println("WARN: Already running transaction.")
		return nil
	}

	globalLock.Lock()
	defer globalLock.Unlock()

	tx, err := ss.backend.db.Begin()
	ss.tx = tx
	return err
}
// CommitTransaction commits the session's running transaction while holding
// the global write lock. Without a running transaction it logs a warning and
// returns nil.
func (ss *SqliteSession) CommitTransaction() error {
	pending := ss.tx
	if pending == nil {
		log.Println("WARN: No transaction running.")
		return nil
	}

	globalLock.Lock()
	defer globalLock.Unlock()

	// Detach before committing, so the session no longer refers to the
	// transaction whether or not Commit succeeds.
	ss.tx = nil
	return pending.Commit()
}
// AllKeys streams every key in the blocks table over the returned channel.
// Each key is encoded with the backend's encoder. n is the row count from
// countSQL, taken before the key scan starts. The queries run inside the
// session's transaction when one is running.
//
// The global read lock is held until the producing goroutine exits. That
// happens after all keys were sent, after a scan/encode/iteration error
// (which is logged), or when done is signalled. The channel is closed when
// the goroutine exits. If the count or key query itself fails, the lock is
// released and keys is nil. hash is not used by this backend.
func (ss *SqliteSession) AllKeys(
	hash []byte,
	done chan struct{}) (keys chan []byte, n int, err error) {

	globalLock.RLock()

	countStmt := ss.txStmt(ss.backend.countStmt)
	if err = countStmt.QueryRow().Scan(&n); err != nil {
		if err == sql.ErrNoRows {
			err = nil
		}
		globalLock.RUnlock()
		return
	}

	keysStmt := ss.txStmt(ss.backend.keysStmt)
	var rows *sql.Rows
	if rows, err = keysStmt.Query(); err != nil {
		globalLock.RUnlock()
		return
	}

	keys = make(chan []byte)
	go func() {
		defer globalLock.RUnlock()
		defer rows.Close()
		defer close(keys)
		// Errors inside the goroutine use local variables only, so the
		// named result err is never written after AllKeys has returned.
		for rows.Next() {
			var key int64
			if err := rows.Scan(&key); err != nil {
				log.Printf("WARN: %s\n", err)
				return
			}
			encoded, err := ss.backend.encoder(key)
			if err != nil {
				log.Printf("Cannot encode key: %d %s\n", key, err)
				return
			}
			select {
			case keys <- encoded:
			case <-done:
				return
			}
		}
		// rows.Next also returns false on an iteration error; without this
		// check the stream would silently end early.
		if err := rows.Err(); err != nil {
			log.Printf("WARN: %s\n", err)
		}
	}()
	return
}
// SpatialQuery streams all blocks inside the cuboid spanned by the encoded
// keys first and second. It delegates to the implementation that matches
// the backend's key layout (interleaved or plain). hash is not used by this
// backend.
func (ss *SqliteSession) SpatialQuery(
	hash, first, second []byte,
	done chan struct{}) (chan Block, error) {

	query := ss.plainSpatialQuery
	if ss.backend.interleaved {
		query = ss.interleavedSpatialQuery
	}
	return query(first, second, done)
}
// interleavedSpatialQuery serves SpatialQuery for the interleaved key
// layout. first and second are decoded as plain keys. The two corners are
// clipped with common.ClipCoord and ordered with MinCoord/MaxCoord. The
// cuboid is then mapped to the interleaved key range [zmin, zmax], which is
// scanned in ascending key order with rangeSQL.
//
// Rows whose coordinate lies inside the cuboid are sent on the returned
// channel, with the key re-encoded as a plain key. When a row outside the
// cuboid is met, the current query is closed and restarted from
// common.BigMin(zmin, zmax, zcode). This presumably yields the next
// interleaved key inside the cuboid after zcode (Z-order "BIGMIN"); confirm
// against common.BigMin.
//
// Decoding errors for first/second are returned directly. All later errors
// are logged and end the stream. done stops the stream early. The global
// read lock is taken before the goroutine starts and released when it
// exits; the channel is closed at the same time.
func (ss *SqliteSession) interleavedSpatialQuery(
	first, second []byte,
	done chan struct{}) (blocks chan Block, err error) {
	var (
		firstKey  int64
		secondKey int64
	)
	if firstKey, err = common.DecodeStringFromBytes(first); err != nil {
		return
	}
	if secondKey, err = common.DecodeStringFromBytes(second); err != nil {
		return
	}
	c1 := common.ClipCoord(common.PlainToCoord(firstKey))
	c2 := common.ClipCoord(common.PlainToCoord(secondKey))
	// Normalise so that c1 is the minimal and c2 the maximal corner.
	c1, c2 = common.MinCoord(c1, c2), common.MaxCoord(c1, c2)
	blocks = make(chan Block)
	// Acquired here and released by the goroutine once streaming ends.
	globalLock.RLock()
	go func() {
		defer close(blocks)
		defer globalLock.RUnlock()
		zmin, zmax := common.CoordToInterleaved(c1), common.CoordToInterleaved(c2)
		// Should not be necessary.
		// NOTE(review): order presumably returns its arguments in ascending
		// order; with c1 <= c2 the interleaved codes should already be ordered.
		zmin, zmax = order(zmin, zmax)
		cub := common.Cuboid{P1: c1, P2: c2}
		rangeStmt := ss.txStmt(ss.backend.rangeStmt)
		// Goroutine-local err shadows the named result, so the goroutine
		// never writes to it after the function has returned.
		var (
			err  error
			rows *sql.Rows
		)
		// Each pass queries the remaining key range [zmin, zmax]; "goto loop"
		// restarts it after zmin has been advanced past a gap.
	loop:
		// log.Printf("query %d %d\n", zmin, zmax)
		if rows, err = rangeStmt.Query(zmin, zmax); err != nil {
			log.Printf("Error in range query: %s\n", err)
			return
		}
		for rows.Next() {
			var zcode int64
			var data []byte
			if err = rows.Scan(&zcode, &data); err != nil {
				log.Printf("Error in scanning row: %s\n", err)
				rows.Close()
				return
			}
			// log.Printf("zcode: %d\n", zcode)
			c := common.InterleavedToCoord(zcode)
			if cub.Contains(c) {
				// Consumers receive plain-encoded keys, not interleaved ones.
				var encodedKey []byte
				if encodedKey, err = common.EncodeStringToBytes(common.CoordToPlain(c)); err != nil {
					log.Printf("Key encoding failed: %s\n", err)
					rows.Close()
					return
				}
				select {
				case blocks <- Block{Key: encodedKey, Data: data}:
				case <-done:
					rows.Close()
					return
				}
			} else { // Left the cuboid
				// log.Printf("Left cuboid %d\n", zcode)
				// Abandon this result set and requery from the next candidate
				// key computed by BigMin instead of scanning through the gap.
				rows.Close()
				zmin = common.BigMin(zmin, zmax, zcode)
				goto loop
			}
		}
		if err = rows.Err(); err != nil {
			log.Printf("Error in range query: %s\n", err)
		}
		rows.Close()
	}()
	return
}
// plainSpatialQuery serves SpatialQuery for the plain key layout. first and
// second are decoded as plain keys and ordered into a minimal and maximal
// corner. The backend's coverage index splits that cuboid into ranges. For
// every range, each Y row from r.Y2 down to r.Y1 is fetched with one range
// query spanning X1..X2 at fixed Y and Z. Every row is sent on the returned
// channel with its key encoded as a plain key.
//
// Decoding errors for first/second are returned directly. Query, scan,
// encoding and iteration errors are logged and end the stream. done stops
// it early. The global read lock is taken before the goroutine starts and
// released when it exits; the channel is closed at the same time.
func (ss *SqliteSession) plainSpatialQuery(
	first, second []byte,
	done chan struct{}) (blocks chan Block, err error) {

	firstKey, err := common.DecodeStringFromBytes(first)
	if err != nil {
		return
	}
	secondKey, err := common.DecodeStringFromBytes(second)
	if err != nil {
		return
	}

	p1, p2 := common.PlainToCoord(firstKey), common.PlainToCoord(secondKey)
	lo, hi := common.MinCoord(p1, p2), common.MaxCoord(p1, p2)

	blocks = make(chan Block)
	globalLock.RLock()
	go func() {
		defer globalLock.RUnlock()
		defer close(blocks)

		rangeStmt := ss.txStmt(ss.backend.rangeStmt)

		// emit forwards all rows of one range query to blocks and closes
		// the result set. It reports whether streaming should continue.
		emit := func(rows *sql.Rows, err error) bool {
			if err != nil {
				log.Printf("Error in range query: %s\n", err)
				return false
			}
			defer rows.Close()
			for rows.Next() {
				var (
					key  int64
					data []byte
				)
				if err := rows.Scan(&key, &data); err != nil {
					log.Printf("Error in scanning row: %s\n", err)
					return false
				}
				encodedKey, err := common.EncodeStringToBytes(key)
				if err != nil {
					log.Printf("Key encoding failed: %s\n", err)
					return false
				}
				select {
				case blocks <- Block{Key: encodedKey, Data: data}:
				case <-done:
					return false
				}
			}
			if err := rows.Err(); err != nil {
				log.Printf("Error in range query: %s\n", err)
				return false
			}
			return true
		}

		// lower and upper are the two ends of the current X span at a
		// fixed Y and Z.
		var lower, upper common.Coord
		for _, r := range ss.backend.coverage.Query(lo, hi) {
			lower.Z, upper.Z = int16(r.Z), int16(r.Z)
			lower.X, upper.X = int16(r.X1), int16(r.X2)
			for lower.Y = r.Y2; lower.Y >= r.Y1; lower.Y-- {
				upper.Y = lower.Y
				if !emit(rangeStmt.Query(common.CoordToPlain(lower), common.CoordToPlain(upper))) {
					return
				}
			}
		}
	}()
	return
}