diff --git a/backend.go b/backend.go index 498fb56..783cd8f 100644 --- a/backend.go +++ b/backend.go @@ -4,11 +4,16 @@ package main -type Backend interface { +type Session interface { Fetch(hash, key []byte) ([]byte, error) InTransaction() bool Store(hash, key, value []byte) (bool, error) BeginTransaction() error CommitTransaction() error + Close() error +} + +type Backend interface { + NewSession() (Session, error) Shutdown() error } diff --git a/connection.go b/connection.go index 9283b8c..b62f426 100644 --- a/connection.go +++ b/connection.go @@ -22,19 +22,22 @@ var ( type Connection struct { conn net.Conn - backend Backend + session Session boolArray []bool } -func NewConnection(conn net.Conn, backend Backend) *Connection { +func NewConnection(conn net.Conn, session Session) *Connection { return &Connection{ conn: conn, - backend: backend, + session: session, boolArray: []bool{}} } func (c *Connection) Run() { - defer c.conn.Close() + defer func() { + c.session.Close() + c.conn.Close() + }() rce := NewRedisCommandExecutor(c) r := bufio.NewReader(c.conn) parser := NewRedisParser(r, rce) @@ -50,7 +53,7 @@ func (c *Connection) Hget(hash, key []byte) bool { var err error var data []byte - if data, err = c.backend.Fetch(hash, key); err != nil { + if data, err = c.session.Fetch(hash, key); err != nil { return c.writeError(err) } @@ -61,11 +64,11 @@ func (c *Connection) Hset(hash, key, data []byte) bool { var err error var exists bool - if exists, err = c.backend.Store(hash, key, data); err != nil { + if exists, err = c.session.Store(hash, key, data); err != nil { return c.writeError(err) } - if c.backend.InTransaction() { + if c.session.InTransaction() { c.boolArray = append(c.boolArray, exists) return c.writeQueued() } @@ -74,10 +77,10 @@ func (c *Connection) Hset(hash, key, data []byte) bool { } func (c *Connection) Multi() bool { - if c.backend.InTransaction() { + if c.session.InTransaction() { log.Println("WARN: Already running transaction.") } else { - if err := c.backend.BeginTransaction(); err != nil { + if err := c.session.BeginTransaction(); err != nil { return c.writeError(err) } } @@ -85,12 +88,12 @@ func (c *Connection) Multi() bool { } func (c *Connection) Exec() bool { - if !c.backend.InTransaction() { + if !c.session.InTransaction() { return c.writeEmptyArray() } arr := c.boolArray c.boolArray = []bool{} - if err := c.backend.CommitTransaction(); err != nil { + if err := c.session.CommitTransaction(); err != nil { return c.writeError(err) } return c.writeBoolArray(arr) diff --git a/leveldb.go b/leveldb.go index 9a3513e..b245e41 100644 --- a/leveldb.go +++ b/leveldb.go @@ -5,13 +5,19 @@ package main import ( + "log" + leveldb "github.com/jmhodges/levigo" ) type LevelDBBackend struct { cache *leveldb.Cache db *leveldb.DB - tx *leveldb.WriteBatch +} + +type LevelDBSession struct { + backend *LevelDBBackend + tx *leveldb.WriteBatch } func NewLeveDBBackend(path string, cacheSize int) (ldb *LevelDBBackend, err error) { @@ -30,40 +36,46 @@ func NewLeveDBBackend(path string, cacheSize int) (ldb *LevelDBBackend, err erro return } -func (ldb *LevelDBBackend) Shutdown() error { - tx := ldb.tx - if tx != nil { - ldb.tx = nil - tx.Close() +func (ldb *LevelDBBackend) NewSession() (Session, error) { + return &LevelDBSession{ldb, nil}, nil +} + +func (ldbs *LevelDBSession) Close() error { + if ldbs.tx != nil { + ldbs.tx.Close() } + return nil +} + +func (ldb *LevelDBBackend) Shutdown() error { ldb.db.Close() ldb.cache.Close() return nil } -func (ldb *LevelDBBackend) Fetch(hash, key []byte) (value []byte, err error) { +func (ldbs *LevelDBSession) Fetch(hash, key []byte) (value []byte, err error) { ro := leveldb.NewReadOptions() - value, err = ldb.db.Get(ro, key) + value, err = ldbs.backend.db.Get(ro, key) ro.Close() return } -func (ldb *LevelDBBackend) InTransaction() bool { - return ldb.tx != nil +func (ldbs *LevelDBSession) InTransaction() bool { + return ldbs.tx != nil } -func (ldb *LevelDBBackend) keyExists(key []byte) (exists bool, err error) { +func (ldbs *LevelDBSession) keyExists(key []byte) (exists bool, err error) { ro := leveldb.NewReadOptions() defer ro.Close() var data []byte - if data, err = ldb.db.Get(ro, key); err != nil { + if data, err = ldbs.backend.db.Get(ro, key); err != nil { return } exists = data != nil return } -func (ldb *LevelDBBackend) Store(hash, key, value []byte) (exists bool, err error) { +func (ldbs *LevelDBSession) Store(hash, key, value []byte) (exists bool, err error) { var pos int64 if pos, err = bytes2pos(key); err != nil { @@ -72,32 +84,36 @@ func (ldb *LevelDBBackend) Store(hash, key, value []byte) (exists bool, err erro // Re-code it to make LevelDB happy. key = pos2bytes(pos) - if exists, err = ldb.keyExists(key); err != nil { + if exists, err = ldbs.keyExists(key); err != nil { return } - if ldb.tx != nil { - ldb.tx.Put(key, value) + if ldbs.tx != nil { + ldbs.tx.Put(key, value) return } wo := leveldb.NewWriteOptions() - err = ldb.db.Put(wo, key, value) + err = ldbs.backend.db.Put(wo, key, value) wo.Close() return } -func (ldb *LevelDBBackend) BeginTransaction() error { - ldb.tx = leveldb.NewWriteBatch() +func (ldbs *LevelDBSession) BeginTransaction() error { + ldbs.tx = leveldb.NewWriteBatch() return nil } -func (ldb *LevelDBBackend) CommitTransaction() (err error) { - tx := ldb.tx - ldb.tx = nil +func (ldbs *LevelDBSession) CommitTransaction() (err error) { + tx := ldbs.tx + if tx == nil { + log.Println("WARN: No transaction running.") + return + } + ldbs.tx = nil wo := leveldb.NewWriteOptions() - err = ldb.db.Write(wo, tx) + err = ldbs.backend.db.Write(wo, tx) wo.Close() return } diff --git a/main.go b/main.go index 9349d58..17928f8 100644 --- a/main.go +++ b/main.go @@ -74,7 +74,13 @@ func main() { for { select { case conn := <-connChan: - go NewConnection(conn, backend).Run() + var session Session + if session, err = backend.NewSession(); err != nil { + log.Printf("Cannot create session: %s", err) + conn.Close() + } else { + go NewConnection(conn, session).Run() + } case <-sigChan: log.Println("Shutting down") return diff --git a/sqlite.go b/sqlite.go index 90a826e..03361ba 100644 --- a/sqlite.go +++ b/sqlite.go @@ -23,13 +23,30 @@ const ( type SqliteBackend struct { db *sql.DB - tx *sql.Tx existsStmt *sql.Stmt fetchStmt *sql.Stmt insertStmt *sql.Stmt updateStmt *sql.Stmt } +type SqliteSession struct { + backend *SqliteBackend + tx *sql.Tx +} + +func (ss *SqliteBackend) NewSession() (Session, error) { + return &SqliteSession{ss, nil}, nil +} + +func (ss *SqliteSession) Close() error { + t := ss.tx + if t != nil { + ss.tx = nil + return t.Rollback() + } + return nil +} + func NewSqliteBackend(path string) (sqlb *SqliteBackend, err error) { res := SqliteBackend{} @@ -62,15 +79,6 @@ func NewSqliteBackend(path string) (sqlb *SqliteBackend, err error) { return } -func rollbackTx(tx **sql.Tx) error { - t := *tx - if t != nil { - *tx = nil - return t.Rollback() - } - return nil -} - func closeStmt(stmt **sql.Stmt) error { s := *stmt if s != nil { @@ -90,7 +98,6 @@ func closeDB(db **sql.DB) error { } func (sqlb *SqliteBackend) closeAll() error { - rollbackTx(&sqlb.tx) closeStmt(&sqlb.fetchStmt) closeStmt(&sqlb.insertStmt) closeStmt(&sqlb.updateStmt) @@ -105,14 +112,14 @@ func (sqlb *SqliteBackend) Shutdown() error { return sqlb.closeAll() } -func (sqlb *SqliteBackend) txStmt(stmt *sql.Stmt) *sql.Stmt { - if sqlb.tx != nil { - return sqlb.tx.Stmt(stmt) +func (ss *SqliteSession) txStmt(stmt *sql.Stmt) *sql.Stmt { + if ss.tx != nil { + return ss.tx.Stmt(stmt) } return stmt } -func (sqlb *SqliteBackend) Fetch(hash, key []byte) (data []byte, err error) { +func (ss *SqliteSession) Fetch(hash, key []byte) (data []byte, err error) { var pos int64 if pos, err = bytes2pos(key); err != nil { return @@ -121,7 +128,7 @@ func (sqlb *SqliteBackend) Fetch(hash, key []byte) (data []byte, err error) { globalLock.RLock() defer globalLock.RUnlock() - fetchStmt := sqlb.txStmt(sqlb.fetchStmt) + fetchStmt := ss.txStmt(ss.backend.fetchStmt) err2 := fetchStmt.QueryRow(pos).Scan(&data) if err2 == sql.ErrNoRows { return @@ -130,11 +137,11 @@ func (sqlb *SqliteBackend) Fetch(hash, key []byte) (data []byte, err error) { return } -func (sqlb *SqliteBackend) InTransaction() bool { - return sqlb.tx != nil +func (ss *SqliteSession) InTransaction() bool { + return ss.tx != nil } -func (sqlb *SqliteBackend) Store(hash, key, value []byte) (exists bool, err error) { +func (ss *SqliteSession) Store(hash, key, value []byte) (exists bool, err error) { var pos int64 if pos, err = bytes2pos(key); err != nil { return @@ -143,7 +150,7 @@ func (sqlb *SqliteBackend) Store(hash, key, value []byte) (exists bool, err erro globalLock.Lock() defer globalLock.Unlock() - existsStmt := sqlb.txStmt(sqlb.existsStmt) + existsStmt := ss.txStmt(ss.backend.existsStmt) var x int err2 := existsStmt.QueryRow(pos).Scan(&x) @@ -157,30 +164,30 @@ func (sqlb *SqliteBackend) Store(hash, key, value []byte) (exists bool, err erro } if exists { - updateStmt := sqlb.txStmt(sqlb.updateStmt) + updateStmt := ss.txStmt(ss.backend.updateStmt) _, err = updateStmt.Exec(value, pos) } else { - insertStmt := sqlb.txStmt(sqlb.insertStmt) + insertStmt := ss.txStmt(ss.backend.insertStmt) _, err = insertStmt.Exec(pos, value) } return } -func (sqlb *SqliteBackend) BeginTransaction() (err error) { - if sqlb.tx != nil { +func (ss *SqliteSession) BeginTransaction() (err error) { + if ss.tx != nil { log.Println("WARN: Already running transaction.") return nil } globalLock.Lock() defer globalLock.Unlock() - sqlb.tx, err = sqlb.db.Begin() + ss.tx, err = ss.backend.db.Begin() return } -func (sqlb *SqliteBackend) CommitTransaction() error { +func (ss *SqliteSession) CommitTransaction() error { - tx := sqlb.tx + tx := ss.tx if tx == nil { log.Println("WARN: No transaction running.") return nil @@ -188,6 +195,6 @@ func (sqlb *SqliteBackend) CommitTransaction() error { globalLock.Lock() defer globalLock.Unlock() - sqlb.tx = nil + ss.tx = nil return tx.Commit() }