// Copyright 2014 by Sascha L. Teichmann // Use of this source code is governed by the MIT license // that can be found in the LICENSE file. package main import ( "database/sql" "log" "strconv" "sync" _ "github.com/mattn/go-sqlite3" ) var globalLock sync.RWMutex const ( fetchSql = "SELECT data FROM blocks WHERE pos = ?" existsSql = "SELECT 1 FROM blocks WHERE pos = ?" updateSql = "UPDATE blocks SET data = ? WHERE pos = ?" insertSql = "INSERT INTO blocks (pos, data) VALUES (?, ?)" ) type SqliteBackend struct { db *sql.DB tx *sql.Tx existsStmt *sql.Stmt fetchStmt *sql.Stmt insertStmt *sql.Stmt updateStmt *sql.Stmt } func NewSqliteBackend(path string) (sqlb *SqliteBackend, err error) { res := SqliteBackend{} if res.db, err = sql.Open("sqlite3", path); err != nil { return } if res.existsStmt, err = res.db.Prepare(existsSql); err != nil { res.closeAll() return } if res.fetchStmt, err = res.db.Prepare(fetchSql); err != nil { res.closeAll() return } if res.insertStmt, err = res.db.Prepare(insertSql); err != nil { res.closeAll() return } if res.updateStmt, err = res.db.Prepare(updateSql); err != nil { res.closeAll() return } sqlb = &res return } func rollbackTx(tx **sql.Tx) error { t := *tx if t != nil { *tx = nil return t.Rollback() } return nil } func closeStmt(stmt **sql.Stmt) error { s := *stmt if s != nil { *stmt = nil return s.Close() } return nil } func closeDB(db **sql.DB) error { d := *db if d != nil { *db = nil return d.Close() } return nil } func (sqlb *SqliteBackend) closeAll() error { rollbackTx(&sqlb.tx) closeStmt(&sqlb.fetchStmt) closeStmt(&sqlb.insertStmt) closeStmt(&sqlb.updateStmt) closeStmt(&sqlb.existsStmt) return closeDB(&sqlb.db) } func (sqlb *SqliteBackend) Shutdown() error { globalLock.Lock() defer globalLock.Unlock() return sqlb.closeAll() } func (sqlb *SqliteBackend) txStmt(stmt *sql.Stmt) *sql.Stmt { if sqlb.tx != nil { return sqlb.tx.Stmt(stmt) } return stmt } func bytes2pos(key []byte) (pos int64, err error) { return strconv.ParseInt(string(key), 10, 64) } func (sqlb *SqliteBackend) Fetch(hash, key []byte) (data []byte, err error) { var pos int64 if pos, err = bytes2pos(key); err != nil { return } globalLock.RLock() defer globalLock.RUnlock() fetchStmt := sqlb.txStmt(sqlb.fetchStmt) err2 := fetchStmt.QueryRow(pos).Scan(&data) if err2 == sql.ErrNoRows { return } err = err2 return } func (sqlb *SqliteBackend) InTransaction() bool { return sqlb.tx != nil } func (sqlb *SqliteBackend) Store(hash, key, value []byte) (exists bool, err error) { var pos int64 if pos, err = bytes2pos(key); err != nil { return } globalLock.Lock() defer globalLock.Unlock() existsStmt := sqlb.txStmt(sqlb.existsStmt) var x int err2 := existsStmt.QueryRow(pos).Scan(&x) if err2 == sql.ErrNoRows { exists = false } else if err2 != nil { err = err2 return } else { exists = true } if exists { updateStmt := sqlb.txStmt(sqlb.updateStmt) _, err = updateStmt.Exec(value, pos) } else { insertStmt := sqlb.txStmt(sqlb.insertStmt) _, err = insertStmt.Exec(pos, value) } return } func (sqlb *SqliteBackend) BeginTransaction() (err error) { if sqlb.tx != nil { log.Println("WARN: Already running transaction.") return nil } globalLock.Lock() defer globalLock.Unlock() sqlb.tx, err = sqlb.db.Begin() return } func (sqlb *SqliteBackend) CommitTransaction() error { tx := sqlb.tx if tx == nil { log.Println("WARN: No transaction running.") return nil } globalLock.Lock() defer globalLock.Unlock() sqlb.tx = nil return tx.Commit() }