// Copyright 2014 by Sascha L. Teichmann // Use of this source code is governed by the MIT license // that can be found in the LICENSE file. package main import ( "database/sql" "log" "sync" _ "github.com/mattn/go-sqlite3" "bitbucket.org/s_l_teichmann/mtsatellite/common" ) var globalLock sync.RWMutex const ( fetchSQL = "SELECT data FROM blocks WHERE pos = ?" existsSQL = "SELECT 1 FROM blocks WHERE pos = ?" updateSQL = "UPDATE blocks SET data = ? WHERE pos = ?" insertSQL = "INSERT INTO blocks (pos, data) VALUES (?, ?)" countSQL = "SELECT count(*) FROM blocks" keysSQL = "SELECT pos FROM blocks" rangeSQL = "SELECT pos, data FROM blocks WHERE pos BETWEEN ? AND ? ORDER BY pos" rangeDuffSQL = "SELECT pos, data FROM blocks WHERE " + "pos BETWEEN ? AND ? OR " + "pos BETWEEN ? AND ? OR " + "pos BETWEEN ? AND ? OR " + "pos BETWEEN ? AND ?" ) type SqliteBackend struct { db *sql.DB encoder common.KeyEncoder decoder common.KeyDecoder changeTracker *ChangeTracker interleaved bool coverage *common.Coverage3D existsStmt *sql.Stmt fetchStmt *sql.Stmt insertStmt *sql.Stmt updateStmt *sql.Stmt countStmt *sql.Stmt keysStmt *sql.Stmt rangeStmt *sql.Stmt } type SqliteSession struct { backend *SqliteBackend tx *sql.Tx } func (sqlb *SqliteBackend) NewSession() (Session, error) { return &SqliteSession{sqlb, nil}, nil } func (ss *SqliteSession) Close() error { t := ss.tx if t != nil { ss.tx = nil return t.Rollback() } return nil } func NewSqliteBackend( path string, changeTracker *ChangeTracker, interleaved bool) (sqlb *SqliteBackend, err error) { res := SqliteBackend{interleaved: interleaved, changeTracker: changeTracker} if res.db, err = sql.Open("sqlite3", path); err != nil { return } if res.existsStmt, err = res.db.Prepare(existsSQL); err != nil { res.closeAll() return } if res.fetchStmt, err = res.db.Prepare(fetchSQL); err != nil { res.closeAll() return } if res.insertStmt, err = res.db.Prepare(insertSQL); err != nil { res.closeAll() return } if res.updateStmt, err = res.db.Prepare(updateSQL); err != nil { res.closeAll() return } if res.countStmt, err = res.db.Prepare(countSQL); err != nil { res.closeAll() return } if res.keysStmt, err = res.db.Prepare(keysSQL); err != nil { res.closeAll() return } var rS string if interleaved { rS = rangeSQL } else { rS = rangeDuffSQL } if res.rangeStmt, err = res.db.Prepare(rS); err != nil { res.closeAll() return } if interleaved { res.encoder = common.EncodeStringToBytesFromInterleaved res.decoder = common.DecodeStringFromBytesToInterleaved } else { res.encoder = common.EncodeStringToBytes res.decoder = common.DecodeStringFromBytes } if !interleaved { if err = res.buildCoverage(); err != nil { return } } sqlb = &res return } func (sb *SqliteBackend) buildCoverage() (err error) { log.Println("INFO: Start building coverage index (this may take some time)...") sb.coverage = common.NewCoverage3D() var rows *sql.Rows if rows, err = sb.keysStmt.Query(); err != nil { return } defer rows.Close() for rows.Next() { var key int64 if err = rows.Scan(&key); err != nil { return } sb.coverage.Insert(common.PlainToCoord(key)) } err = rows.Err() log.Println("INFO: Finished building coverage index.") return } func closeStmt(stmt **sql.Stmt) error { s := *stmt if s != nil { *stmt = nil return s.Close() } return nil } func closeDB(db **sql.DB) error { d := *db if d != nil { *db = nil return d.Close() } return nil } func (sqlb *SqliteBackend) closeAll() error { closeStmt(&sqlb.fetchStmt) closeStmt(&sqlb.insertStmt) closeStmt(&sqlb.updateStmt) closeStmt(&sqlb.existsStmt) closeStmt(&sqlb.countStmt) closeStmt(&sqlb.keysStmt) closeStmt(&sqlb.rangeStmt) return closeDB(&sqlb.db) } func (sqlb *SqliteBackend) Shutdown() error { globalLock.Lock() defer globalLock.Unlock() return sqlb.closeAll() } func (ss *SqliteSession) txStmt(stmt *sql.Stmt) *sql.Stmt { if ss.tx != nil { return ss.tx.Stmt(stmt) } return stmt } func (ss *SqliteSession) Fetch(hash, key []byte) (data []byte, err error) { var pos int64 if pos, err = ss.backend.decoder(key); err != nil { return } globalLock.RLock() defer globalLock.RUnlock() fetchStmt := ss.txStmt(ss.backend.fetchStmt) err2 := fetchStmt.QueryRow(pos).Scan(&data) if err2 == sql.ErrNoRows { return } err = err2 return } func (ss *SqliteSession) InTransaction() bool { return ss.tx != nil } func (ss *SqliteSession) Store(hash, key, value []byte) (exists bool, err error) { var pos int64 if pos, err = ss.backend.decoder(key); err != nil { return } globalLock.Lock() defer globalLock.Unlock() existsStmt := ss.txStmt(ss.backend.existsStmt) var x int err2 := existsStmt.QueryRow(pos).Scan(&x) if err2 == sql.ErrNoRows { exists = false } else if err2 != nil { err = err2 return } else { exists = true } if exists { updateStmt := ss.txStmt(ss.backend.updateStmt) _, err = updateStmt.Exec(value, pos) } else { insertStmt := ss.txStmt(ss.backend.insertStmt) _, err = insertStmt.Exec(pos, value) } // This technically too early because this done in transactions // which are commited (and possible fail) later. if ss.backend.coverage != nil { ss.backend.coverage.Insert(common.PlainToCoord(pos)) } if ss.backend.changeTracker != nil { ss.backend.changeTracker.BlockChanged(key) } return } func (ss *SqliteSession) BeginTransaction() (err error) { if ss.tx != nil { log.Println("WARN: Already running transaction.") return nil } globalLock.Lock() defer globalLock.Unlock() ss.tx, err = ss.backend.db.Begin() return } func (ss *SqliteSession) CommitTransaction() error { tx := ss.tx if tx == nil { log.Println("WARN: No transaction running.") return nil } globalLock.Lock() defer globalLock.Unlock() ss.tx = nil return tx.Commit() } func (ss *SqliteSession) AllKeys(hash []byte, done chan struct{}) (keys chan []byte, n int, err error) { globalLock.RLock() countStmt := ss.txStmt(ss.backend.countStmt) if err = countStmt.QueryRow().Scan(&n); err != nil { if err == sql.ErrNoRows { err = nil } globalLock.RUnlock() return } keysStmt := ss.txStmt(ss.backend.keysStmt) var rows *sql.Rows if rows, err = keysStmt.Query(); err != nil { globalLock.RUnlock() return } keys = make(chan []byte) go func() { defer globalLock.RUnlock() defer rows.Close() defer close(keys) var err error for rows.Next() { var key int64 if err := rows.Scan(&key); err != nil { log.Printf("WARN: %s\n", err) break } var encoded []byte if encoded, err = ss.backend.encoder(key); err != nil { log.Printf("Cannot encode key: %d %s\n", key, err) break } select { case keys <- encoded: case <-done: return } } }() return } func (ss *SqliteSession) SpatialQuery(hash, first, second []byte, done chan struct{}) (chan Block, error) { if ss.backend.interleaved { return ss.interleavedSpatialQuery(first, second, done) } return ss.plainSpatialQuery(first, second, done) } func (ss *SqliteSession) interleavedSpatialQuery(first, second []byte, done chan struct{}) (blocks chan Block, err error) { var ( firstKey int64 secondKey int64 ) if firstKey, err = common.DecodeStringFromBytes(first); err != nil { return } if secondKey, err = common.DecodeStringFromBytes(second); err != nil { return } c1 := common.ClipCoord(common.PlainToCoord(firstKey)) c2 := common.ClipCoord(common.PlainToCoord(secondKey)) c1, c2 = common.MinCoord(c1, c2), common.MaxCoord(c1, c2) blocks = make(chan Block) globalLock.RLock() go func() { defer close(blocks) defer globalLock.RUnlock() zmin, zmax := common.CoordToInterleaved(c1), common.CoordToInterleaved(c2) // Should not be necessary. zmin, zmax = order(zmin, zmax) cub := common.Cuboid{P1: c1, P2: c2} rangeStmt := ss.txStmt(ss.backend.rangeStmt) var ( err error rows *sql.Rows ) loop: // log.Printf("query %d %d\n", zmin, zmax) if rows, err = rangeStmt.Query(zmin, zmax); err != nil { log.Printf("Error in range query: %s\n", err) return } for rows.Next() { var zcode int64 var data []byte if err = rows.Scan(&zcode, &data); err != nil { log.Printf("Error in scanning row: %s\n", err) rows.Close() return } // log.Printf("zcode: %d\n", zcode) c := common.InterleavedToCoord(zcode) if cub.Contains(c) { var encodedKey []byte if encodedKey, err = common.EncodeStringToBytes(common.CoordToPlain(c)); err != nil { log.Printf("Key encoding failed: %s\n", err) rows.Close() return } select { case blocks <- Block{Key: encodedKey, Data: data}: case <-done: rows.Close() return } } else { // Left the cuboid // log.Printf("Left cuboid %d\n", zcode) rows.Close() zmin = common.BigMin(zmin, zmax, zcode) goto loop } } if err = rows.Err(); err != nil { log.Printf("Error in range query: %s\n", err) } rows.Close() }() return } type duffStmt struct { stmt *sql.Stmt counter int params [8]int64 } func (ds *duffStmt) push(a, b int64) bool { ds.params[ds.counter] = a ds.counter++ ds.params[ds.counter] = b ds.counter++ return ds.counter > 7 } func (ds *duffStmt) Query() (*sql.Rows, error) { c := ds.counter ds.counter = 0 switch c { case 8: return ds.stmt.Query( ds.params[0], ds.params[1], ds.params[2], ds.params[3], ds.params[4], ds.params[5], ds.params[0], ds.params[1]) case 6: return ds.stmt.Query( ds.params[0], ds.params[1], ds.params[2], ds.params[3], ds.params[4], ds.params[5], ds.params[0], ds.params[1]) case 4: return ds.stmt.Query( ds.params[0], ds.params[1], ds.params[2], ds.params[3], ds.params[0], ds.params[1], ds.params[0], ds.params[1]) case 2: return ds.stmt.Query( ds.params[0], ds.params[1], ds.params[0], ds.params[1], ds.params[0], ds.params[1], ds.params[0], ds.params[1]) } return nil, nil } func (ss *SqliteSession) plainSpatialQuery(first, second []byte, done chan struct{}) (blocks chan Block, err error) { var ( firstKey int64 secondKey int64 ) if firstKey, err = common.DecodeStringFromBytes(first); err != nil { return } if secondKey, err = common.DecodeStringFromBytes(second); err != nil { return } c1 := common.PlainToCoord(firstKey) c2 := common.PlainToCoord(secondKey) c1, c2 = common.MinCoord(c1, c2), common.MaxCoord(c1, c2) blocks = make(chan Block) globalLock.RLock() go func() { defer globalLock.RUnlock() defer close(blocks) rangeStmt := duffStmt{stmt: ss.txStmt(ss.backend.rangeStmt)} send := func(rows *sql.Rows, err error) bool { if rows == nil { return true } if err != nil { log.Printf("Error in range query: %s\n", err) return false } defer rows.Close() for rows.Next() { var key int64 var data []byte if err = rows.Scan(&key, &data); err != nil { log.Printf("Error in scanning row: %s\n", err) return false } var encodedKey []byte if encodedKey, err = common.EncodeStringToBytes(key); err != nil { log.Printf("Key encoding failed: %s\n", err) return false } select { case blocks <- Block{Key: encodedKey, Data: data}: case <-done: return false } } if err = rows.Err(); err != nil { log.Printf("Error in range query: %s\n", err) return false } return true } if ss.backend.coverage == nil { a, b := common.Coord{X: c1.X}, common.Coord{X: c2.X} for a.Y = c2.Y; a.Y >= c1.Y; a.Y-- { b.Y = a.Y for a.Z = c1.Z; a.Z <= c2.Z; a.Z++ { b.Z = a.Z // Ordering should not be necessary. from, to := order(common.CoordToPlain(a), common.CoordToPlain(b)) if rangeStmt.push(from, to) && !send(rangeStmt.Query()) { return } } } } else { var a, b common.Coord for _, r := range ss.backend.coverage.Query(c1, c2) { a.Z, b.Z = int16(r.Z), int16(r.Z) a.X, b.X = int16(r.X1), int16(r.X2) // log.Printf("y1 y2 x1 x2 z: %d %d, %d %d, %d\n", r.Y1, r.Y2, r.X1, r.X2, r.Z) for y := r.Y2; y >= r.Y1; y-- { a.Y, b.Y = int16(y), int16(y) from, to := order(common.CoordToPlain(a), common.CoordToPlain(b)) if rangeStmt.push(from, to) && !send(rangeStmt.Query()) { return } } } } send(rangeStmt.Query()) }() return }