Introduced a backend session to track transactions more clearly.

This commit is contained in:
Sascha L. Teichmann 2014-08-06 01:11:41 +02:00
parent 737b09436c
commit 3f97399a82
5 changed files with 101 additions and 64 deletions

View File

@ -4,11 +4,16 @@
package main package main
type Backend interface { type Session interface {
Fetch(hash, key []byte) ([]byte, error) Fetch(hash, key []byte) ([]byte, error)
InTransaction() bool InTransaction() bool
Store(hash, key, value []byte) (bool, error) Store(hash, key, value []byte) (bool, error)
BeginTransaction() error BeginTransaction() error
CommitTransaction() error CommitTransaction() error
Close() error
}
type Backend interface {
NewSession() (Session, error)
Shutdown() error Shutdown() error
} }

View File

@ -22,19 +22,22 @@ var (
type Connection struct { type Connection struct {
conn net.Conn conn net.Conn
backend Backend session Session
boolArray []bool boolArray []bool
} }
func NewConnection(conn net.Conn, backend Backend) *Connection { func NewConnection(conn net.Conn, session Session) *Connection {
return &Connection{ return &Connection{
conn: conn, conn: conn,
backend: backend, session: session,
boolArray: []bool{}} boolArray: []bool{}}
} }
func (c *Connection) Run() { func (c *Connection) Run() {
defer c.conn.Close() defer func() {
c.session.Close()
c.conn.Close()
}()
rce := NewRedisCommandExecutor(c) rce := NewRedisCommandExecutor(c)
r := bufio.NewReader(c.conn) r := bufio.NewReader(c.conn)
parser := NewRedisParser(r, rce) parser := NewRedisParser(r, rce)
@ -50,7 +53,7 @@ func (c *Connection) Hget(hash, key []byte) bool {
var err error var err error
var data []byte var data []byte
if data, err = c.backend.Fetch(hash, key); err != nil { if data, err = c.session.Fetch(hash, key); err != nil {
return c.writeError(err) return c.writeError(err)
} }
@ -61,11 +64,11 @@ func (c *Connection) Hset(hash, key, data []byte) bool {
var err error var err error
var exists bool var exists bool
if exists, err = c.backend.Store(hash, key, data); err != nil { if exists, err = c.session.Store(hash, key, data); err != nil {
return c.writeError(err) return c.writeError(err)
} }
if c.backend.InTransaction() { if c.session.InTransaction() {
c.boolArray = append(c.boolArray, exists) c.boolArray = append(c.boolArray, exists)
return c.writeQueued() return c.writeQueued()
} }
@ -74,10 +77,10 @@ func (c *Connection) Hset(hash, key, data []byte) bool {
} }
func (c *Connection) Multi() bool { func (c *Connection) Multi() bool {
if c.backend.InTransaction() { if c.session.InTransaction() {
log.Println("WARN: Already running transaction.") log.Println("WARN: Already running transaction.")
} else { } else {
if err := c.backend.BeginTransaction(); err != nil { if err := c.session.BeginTransaction(); err != nil {
return c.writeError(err) return c.writeError(err)
} }
} }
@ -85,12 +88,12 @@ func (c *Connection) Multi() bool {
} }
func (c *Connection) Exec() bool { func (c *Connection) Exec() bool {
if !c.backend.InTransaction() { if !c.session.InTransaction() {
return c.writeEmptyArray() return c.writeEmptyArray()
} }
arr := c.boolArray arr := c.boolArray
c.boolArray = []bool{} c.boolArray = []bool{}
if err := c.backend.CommitTransaction(); err != nil { if err := c.session.CommitTransaction(); err != nil {
return c.writeError(err) return c.writeError(err)
} }
return c.writeBoolArray(arr) return c.writeBoolArray(arr)

View File

@ -5,12 +5,18 @@
package main package main
import ( import (
"log"
leveldb "github.com/jmhodges/levigo" leveldb "github.com/jmhodges/levigo"
) )
type LevelDBBackend struct { type LevelDBBackend struct {
cache *leveldb.Cache cache *leveldb.Cache
db *leveldb.DB db *leveldb.DB
}
type LevelDBSession struct {
backend *LevelDBBackend
tx *leveldb.WriteBatch tx *leveldb.WriteBatch
} }
@ -30,40 +36,46 @@ func NewLeveDBBackend(path string, cacheSize int) (ldb *LevelDBBackend, err erro
return return
} }
func (ldb *LevelDBBackend) Shutdown() error { func (ldb *LevelDBBackend) NewSession() (Session, error) {
tx := ldb.tx return &LevelDBSession{ldb, nil}, nil
if tx != nil { }
ldb.tx = nil
tx.Close() func (ldbs *LevelDBSession) Close() error {
if ldbs.tx != nil {
ldbs.tx.Close()
} }
return nil
}
func (ldb *LevelDBBackend) Shutdown() error {
ldb.db.Close() ldb.db.Close()
ldb.cache.Close() ldb.cache.Close()
return nil return nil
} }
func (ldb *LevelDBBackend) Fetch(hash, key []byte) (value []byte, err error) { func (ldbs *LevelDBSession) Fetch(hash, key []byte) (value []byte, err error) {
ro := leveldb.NewReadOptions() ro := leveldb.NewReadOptions()
value, err = ldb.db.Get(ro, key) value, err = ldbs.backend.db.Get(ro, key)
ro.Close() ro.Close()
return return
} }
func (ldb *LevelDBBackend) InTransaction() bool { func (ldbs *LevelDBSession) InTransaction() bool {
return ldb.tx != nil return ldbs.tx != nil
} }
func (ldb *LevelDBBackend) keyExists(key []byte) (exists bool, err error) { func (ldbs *LevelDBSession) keyExists(key []byte) (exists bool, err error) {
ro := leveldb.NewReadOptions() ro := leveldb.NewReadOptions()
defer ro.Close() defer ro.Close()
var data []byte var data []byte
if data, err = ldb.db.Get(ro, key); err != nil { if data, err = ldbs.backend.db.Get(ro, key); err != nil {
return return
} }
exists = data != nil exists = data != nil
return return
} }
func (ldb *LevelDBBackend) Store(hash, key, value []byte) (exists bool, err error) { func (ldbs *LevelDBSession) Store(hash, key, value []byte) (exists bool, err error) {
var pos int64 var pos int64
if pos, err = bytes2pos(key); err != nil { if pos, err = bytes2pos(key); err != nil {
@ -72,32 +84,36 @@ func (ldb *LevelDBBackend) Store(hash, key, value []byte) (exists bool, err erro
// Re-code it to make LevelDB happy. // Re-code it to make LevelDB happy.
key = pos2bytes(pos) key = pos2bytes(pos)
if exists, err = ldb.keyExists(key); err != nil { if exists, err = ldbs.keyExists(key); err != nil {
return return
} }
if ldb.tx != nil { if ldbs.tx != nil {
ldb.tx.Put(key, value) ldbs.tx.Put(key, value)
return return
} }
wo := leveldb.NewWriteOptions() wo := leveldb.NewWriteOptions()
err = ldb.db.Put(wo, key, value) err = ldbs.backend.db.Put(wo, key, value)
wo.Close() wo.Close()
return return
} }
func (ldb *LevelDBBackend) BeginTransaction() error { func (ldbs *LevelDBSession) BeginTransaction() error {
ldb.tx = leveldb.NewWriteBatch() ldbs.tx = leveldb.NewWriteBatch()
return nil return nil
} }
func (ldb *LevelDBBackend) CommitTransaction() (err error) { func (ldbs *LevelDBSession) CommitTransaction() (err error) {
tx := ldb.tx tx := ldbs.tx
ldb.tx = nil if tx == nil {
log.Println("WARN: No transaction running.")
return
}
ldbs.tx = nil
wo := leveldb.NewWriteOptions() wo := leveldb.NewWriteOptions()
err = ldb.db.Write(wo, tx) err = ldbs.backend.db.Write(wo, tx)
wo.Close() wo.Close()
return return
} }

View File

@ -74,7 +74,13 @@ func main() {
for { for {
select { select {
case conn := <-connChan: case conn := <-connChan:
go NewConnection(conn, backend).Run() var session Session
if session, err = backend.NewSession(); err != nil {
log.Printf("Cannot create session: %s", err)
conn.Close()
} else {
go NewConnection(conn, session).Run()
}
case <-sigChan: case <-sigChan:
log.Println("Shutting down") log.Println("Shutting down")
return return

View File

@ -23,13 +23,30 @@ const (
type SqliteBackend struct { type SqliteBackend struct {
db *sql.DB db *sql.DB
tx *sql.Tx
existsStmt *sql.Stmt existsStmt *sql.Stmt
fetchStmt *sql.Stmt fetchStmt *sql.Stmt
insertStmt *sql.Stmt insertStmt *sql.Stmt
updateStmt *sql.Stmt updateStmt *sql.Stmt
} }
type SqliteSession struct {
backend *SqliteBackend
tx *sql.Tx
}
func (ss *SqliteBackend) NewSession() (Session, error) {
return &SqliteSession{ss, nil}, nil
}
func (ss *SqliteSession) Close() error {
t := ss.tx
if t != nil {
ss.tx = nil
return t.Rollback()
}
return nil
}
func NewSqliteBackend(path string) (sqlb *SqliteBackend, err error) { func NewSqliteBackend(path string) (sqlb *SqliteBackend, err error) {
res := SqliteBackend{} res := SqliteBackend{}
@ -62,15 +79,6 @@ func NewSqliteBackend(path string) (sqlb *SqliteBackend, err error) {
return return
} }
func rollbackTx(tx **sql.Tx) error {
t := *tx
if t != nil {
*tx = nil
return t.Rollback()
}
return nil
}
func closeStmt(stmt **sql.Stmt) error { func closeStmt(stmt **sql.Stmt) error {
s := *stmt s := *stmt
if s != nil { if s != nil {
@ -90,7 +98,6 @@ func closeDB(db **sql.DB) error {
} }
func (sqlb *SqliteBackend) closeAll() error { func (sqlb *SqliteBackend) closeAll() error {
rollbackTx(&sqlb.tx)
closeStmt(&sqlb.fetchStmt) closeStmt(&sqlb.fetchStmt)
closeStmt(&sqlb.insertStmt) closeStmt(&sqlb.insertStmt)
closeStmt(&sqlb.updateStmt) closeStmt(&sqlb.updateStmt)
@ -105,14 +112,14 @@ func (sqlb *SqliteBackend) Shutdown() error {
return sqlb.closeAll() return sqlb.closeAll()
} }
func (sqlb *SqliteBackend) txStmt(stmt *sql.Stmt) *sql.Stmt { func (ss *SqliteSession) txStmt(stmt *sql.Stmt) *sql.Stmt {
if sqlb.tx != nil { if ss.tx != nil {
return sqlb.tx.Stmt(stmt) return ss.tx.Stmt(stmt)
} }
return stmt return stmt
} }
func (sqlb *SqliteBackend) Fetch(hash, key []byte) (data []byte, err error) { func (ss *SqliteSession) Fetch(hash, key []byte) (data []byte, err error) {
var pos int64 var pos int64
if pos, err = bytes2pos(key); err != nil { if pos, err = bytes2pos(key); err != nil {
return return
@ -121,7 +128,7 @@ func (sqlb *SqliteBackend) Fetch(hash, key []byte) (data []byte, err error) {
globalLock.RLock() globalLock.RLock()
defer globalLock.RUnlock() defer globalLock.RUnlock()
fetchStmt := sqlb.txStmt(sqlb.fetchStmt) fetchStmt := ss.txStmt(ss.backend.fetchStmt)
err2 := fetchStmt.QueryRow(pos).Scan(&data) err2 := fetchStmt.QueryRow(pos).Scan(&data)
if err2 == sql.ErrNoRows { if err2 == sql.ErrNoRows {
return return
@ -130,11 +137,11 @@ func (sqlb *SqliteBackend) Fetch(hash, key []byte) (data []byte, err error) {
return return
} }
func (sqlb *SqliteBackend) InTransaction() bool { func (ss *SqliteSession) InTransaction() bool {
return sqlb.tx != nil return ss.tx != nil
} }
func (sqlb *SqliteBackend) Store(hash, key, value []byte) (exists bool, err error) { func (ss *SqliteSession) Store(hash, key, value []byte) (exists bool, err error) {
var pos int64 var pos int64
if pos, err = bytes2pos(key); err != nil { if pos, err = bytes2pos(key); err != nil {
return return
@ -143,7 +150,7 @@ func (sqlb *SqliteBackend) Store(hash, key, value []byte) (exists bool, err erro
globalLock.Lock() globalLock.Lock()
defer globalLock.Unlock() defer globalLock.Unlock()
existsStmt := sqlb.txStmt(sqlb.existsStmt) existsStmt := ss.txStmt(ss.backend.existsStmt)
var x int var x int
err2 := existsStmt.QueryRow(pos).Scan(&x) err2 := existsStmt.QueryRow(pos).Scan(&x)
@ -157,30 +164,30 @@ func (sqlb *SqliteBackend) Store(hash, key, value []byte) (exists bool, err erro
} }
if exists { if exists {
updateStmt := sqlb.txStmt(sqlb.updateStmt) updateStmt := ss.txStmt(ss.backend.updateStmt)
_, err = updateStmt.Exec(value, pos) _, err = updateStmt.Exec(value, pos)
} else { } else {
insertStmt := sqlb.txStmt(sqlb.insertStmt) insertStmt := ss.txStmt(ss.backend.insertStmt)
_, err = insertStmt.Exec(pos, value) _, err = insertStmt.Exec(pos, value)
} }
return return
} }
func (sqlb *SqliteBackend) BeginTransaction() (err error) { func (ss *SqliteSession) BeginTransaction() (err error) {
if sqlb.tx != nil { if ss.tx != nil {
log.Println("WARN: Already running transaction.") log.Println("WARN: Already running transaction.")
return nil return nil
} }
globalLock.Lock() globalLock.Lock()
defer globalLock.Unlock() defer globalLock.Unlock()
sqlb.tx, err = sqlb.db.Begin() ss.tx, err = ss.backend.db.Begin()
return return
} }
func (sqlb *SqliteBackend) CommitTransaction() error { func (ss *SqliteSession) CommitTransaction() error {
tx := sqlb.tx tx := ss.tx
if tx == nil { if tx == nil {
log.Println("WARN: No transaction running.") log.Println("WARN: No transaction running.")
return nil return nil
@ -188,6 +195,6 @@ func (sqlb *SqliteBackend) CommitTransaction() error {
globalLock.Lock() globalLock.Lock()
defer globalLock.Unlock() defer globalLock.Unlock()
sqlb.tx = nil ss.tx = nil
return tx.Commit() return tx.Commit()
} }