package main import ( "bufio" "bytes" "flag" "fmt" "log" "net" "os" "os/signal" "strconv" leveldb "github.com/jmhodges/levigo" ) type itemConsumer interface { consumeInteger(int64) consumeArray(int64) consumeSimpleString(string) consumeBulkString([]byte) consumeError(error) } type lexer struct { reader *bufio.Reader consumer itemConsumer } func newLexer(reader *bufio.Reader, consumer itemConsumer) *lexer { return &lexer{reader, consumer} } func (l *lexer) nextLine() []byte { line, err := l.reader.ReadBytes('\n') if err != nil { l.consumer.consumeError(err) return nil } line = bytes.TrimRight(line, "\r\n") log.Printf("%q", line) return line } func (l *lexer) dispatch(line []byte) bool { if len(line) < 1 { return false } switch line[0] { case '-': return true // ignore errors case ':': return l.integer(line) case '+': return l.simpleString(line) case '$': return l.bulkString(line) case '*': return l.array(line) } return true } func (l *lexer) simpleString(line []byte) bool { l.consumer.consumeSimpleString(string(line[1:])) return true } func (l *lexer) integer(line []byte) bool { i, err := strconv.ParseInt(string(line[1:]), 10, 64) if err != nil { l.consumer.consumeError(err) return false } l.consumer.consumeInteger(i) return true } func (l *lexer) bulkString(line []byte) bool { var i int64 var err error i, err = strconv.ParseInt(string(line[1:]), 10, 64) if err != nil { l.consumer.consumeError(err) return false } switch { case i < 0: l.consumer.consumeBulkString(nil) case i == 0: l.consumer.consumeBulkString([]byte{}) default: data := make([]byte, i, i) var n int if n, err = l.reader.Read(data); err != nil { l.consumer.consumeError(err) return false } if _, err = l.reader.ReadBytes('\n'); err != nil { l.consumer.consumeError(err) return false } l.consumer.consumeBulkString(data[0:n]) } return true } func (l *lexer) array(line []byte) bool { var i int64 var err error i, err = strconv.ParseInt(string(line[1:]), 10, 64) if err != nil { l.consumer.consumeError(err) return false } l.consumer.consumeArray(i) return true } func (l *lexer) lex() { for line := l.nextLine(); line != nil && l.dispatch(line); line = l.nextLine() { } } type commandExecutor interface { get([]byte) set(key, block []byte) startTransaction() commitTransaction() } type commandConsumer struct { exec commandExecutor missing int64 elements []interface{} } func newCommandConsumer(exec commandExecutor) *commandConsumer { return &commandConsumer{ exec, 0, []interface{}{}} } func (cc *commandConsumer) push(i interface{}) { cc.elements = append(cc.elements, i) cc.missing-- if cc.missing <= 0 { cc.missing = 0 cc.execute() cc.elements = []interface{}{} } } func asString(i interface{}) string { switch i.(type) { case string: return i.(string) case []byte: return string(i.([]byte)) } return fmt.Sprintf("%s", i) } func (cc *commandConsumer) execute() { l := len(cc.elements) log.Printf("command length: %d", l) if l < 1 { log.Printf("WARN: Too less argument for command") return } cmd := asString(cc.elements[0]) switch cmd { case "HGET": if l < 3 { log.Println("WARN: Missing argment for HGET") return } // ignore hash if block, ok := cc.elements[2].([]byte); ok { log.Printf("HGET %d", cc.elements[2]) cc.exec.get(block) } else { log.Println("WARN HGET data is not a byte slice") } case "HSET": if l < 4 { log.Println("WARN: Missing argment for HSET") return } // ignore hash key, ok1 := cc.elements[2].([]byte) block, ok2 := cc.elements[3].([]byte) if !ok1 || !ok2 { log.Printf("WARM HSET key or data is not a byte slice") return } cc.exec.set(key, block) case "MULTI": cc.exec.startTransaction() case "EXEC": cc.exec.commitTransaction() default: log.Printf("UNKOWN command: %s", cmd) } } func (cc *commandConsumer) consumeSimpleString(s string) { log.Printf("simple string: %s", s) cc.push(s) } func shorten(data []byte) string { if len(data) > 10 { return fmt.Sprintf("%.10q...", data) } return fmt.Sprintf("%q", data) } func (cc *commandConsumer) consumeBulkString(data []byte) { s := shorten(data) log.Printf("buld string: len = %d: %s", len(data), s) cc.push(data) } func (cc *commandConsumer) consumeInteger(i int64) { log.Printf("integer: %d", i) cc.push(i) } func (cc *commandConsumer) consumeError(err error) { log.Printf("error: %s", err) } func (cc *commandConsumer) consumeArray(i int64) { log.Printf("array: %d", i) if cc.missing > 0 { log.Println("WARN: Nested arrays are not supported!") return } if i < 0 { log.Println("Null arrays are not supported") return } cc.missing = i } type connection struct { conn net.Conn db *leveldb.DB tx *leveldb.WriteBatch intArray []int } func newConnection(conn net.Conn, db *leveldb.DB) *connection { return &connection{ conn, db, nil, []int{}} } func (c *connection) run() { defer c.conn.Close() cc := newCommandConsumer(c) r := bufio.NewReader(c.conn) lexer := newLexer(r, cc) lexer.lex() } func (c *connection) get(key []byte) { log.Printf("client requested block: %q", key) var err error var data []byte ro := leveldb.NewReadOptions() defer ro.Close() if data, err = c.db.Get(ro, key); err != nil { log.Printf("Something is wrong with db: %s", err) if err = c.writeError(); err != nil { log.Printf("Send message to client failed: %s", err) } } else { if err = c.writeBlock(data); err != nil { log.Printf("Send message to client failed: %s", err) } } } func (c *connection) set(key, block []byte) { log.Printf("client wants to store block: %q", key) var err error var exists int if exists, err = c.keyExists(key); err != nil { log.Printf("Something is wrong with db: %s", err) if err = c.writeError(); err != nil { log.Printf("Writing message to client failed: %s", err) } return } if c.tx != nil { c.tx.Put(key, block) c.intArray = append(c.intArray, exists) if err = c.writeQueued(); err != nil { log.Printf("Writing message to client failed: %s", err) } return } else { wo := leveldb.NewWriteOptions() defer wo.Close() if err = c.db.Put(wo, key, block); err != nil { log.Printf("Something is wrong with db: %s", err) if err = c.writeError(); err != nil { log.Printf("Writing message to client failed: %s", err) } return } if err = c.writeInteger(exists); err != nil { log.Printf("Writing message to client failed: %s", err) } } } func (c *connection) keyExists(key []byte) (exists int, err error) { ro := leveldb.NewReadOptions() defer ro.Close() var data []byte if data, err = c.db.Get(ro, key); err != nil { return } if data != nil { exists = 1 } else { exists = 0 } return } func (c *connection) startTransaction() { if c.tx != nil { log.Println("WARN: Already running transaction.") } else { c.tx = leveldb.NewWriteBatch() } if err := c.writeOk(); err != nil { log.Printf("Writing message to client failed: %s", err) } } func (c *connection) commitTransaction() { var err error if c.tx == nil { if err = c.writeEmptyArray(); err != nil { log.Printf("Writing message to client failed: %s", err) } return } tx := c.tx c.tx = nil defer tx.Close() arr := c.intArray c.intArray = []int{} wo := leveldb.NewWriteOptions() defer wo.Close() if err = c.db.Write(wo, tx); err != nil { log.Printf("Something went wrong in writing transaction: %s", err) if err = c.writeError(); err != nil { log.Printf("Writing message to client failed: %s", err) } return } if err = c.writeIntegerArray(arr); err != nil { log.Printf("Writing message to client failed: %s", err) } } var ( redisOk = []byte("+OK\r\n") redisDbError = []byte("-FAIL\r\n") redisNoSuchBlock = []byte("$-1\r\n") redisCrnl = []byte("\r\n") redisEmptyArray = []byte("*0\r\n") redisQueued = []byte("+QUEUED\r\n") ) func (c *connection) writeError() (err error) { _, err = c.conn.Write(redisDbError) return } func (c *connection) writeEmptyArray() (err error) { _, err = c.conn.Write(redisDbError) return } func (c *connection) writeInteger(v int) (err error) { _, err = c.conn.Write([]byte(fmt.Sprintf(":%d\r\n", v))) return } func (c *connection) writeIntegerArray(arr []int) (err error) { if _, err = c.conn.Write([]byte(fmt.Sprintf("*%d\r\n", len(arr)))); err != nil { return } for x := range arr { if err = c.writeInteger(x); err != nil { return } } return } func (c *connection) writeOk() (err error) { _, err = c.conn.Write(redisOk) return } func (c *connection) writeQueued() (err error) { _, err = c.conn.Write(redisQueued) return } func (c *connection) writeBlock(data []byte) (err error) { con := c.conn if data == nil { _, err = con.Write(redisNoSuchBlock) } else { if _, err = con.Write([]byte(fmt.Sprintf("$%d\r\n", len(data)))); err != nil { return } if _, err = con.Write(data); err != nil { return } _, err = con.Write(redisCrnl) } return } func main() { var port int var host string var cacheSize int flag.IntVar(&port, "port", 6379, "port to bind") flag.StringVar(&host, "host", "", "host to bind") flag.IntVar(&cacheSize, "cache", 32, "cache size in MB") flag.Parse() args := flag.Args() if len(args) < 1 { log.Fatal("Missing path to world") } cache := leveldb.NewLRUCache(cacheSize * 1024 * 1024) defer cache.Close() opts := leveldb.NewOptions() opts.SetCache(cache) opts.SetCreateIfMissing(true) var err error var db *leveldb.DB if db, err = leveldb.Open(args[0], opts); err != nil { log.Fatal(err) } defer db.Close() var listener net.Listener listener, err = net.Listen("tcp", fmt.Sprintf("%s:%d", host, port)) if err != nil { log.Fatal(err) } defer listener.Close() connChan := make(chan net.Conn) defer close(connChan) sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, os.Interrupt, os.Kill) go func() { log.Println("Server started") for { conn, err := listener.Accept() if err != nil { log.Fatal(err) } log.Printf("Client accepted: %s", conn) connChan <- conn } }() for { select { case conn := <-connChan: log.Printf("New connection %s", conn) go newConnection(conn, db).run() case <-sigChan: log.Println("Shutting down") return } } }