commit 1b36ba7c5bfcd884d50b0ac84e666ce3fa4b4fe4 Author: Sascha L. Teichmann Date: Sun Aug 3 00:01:14 2014 +0200 initial checkin diff --git a/main.go b/main.go new file mode 100644 index 0000000..7ee96e6 --- /dev/null +++ b/main.go @@ -0,0 +1,509 @@ +package main + +import ( + "bufio" + "bytes" + "flag" + "fmt" + "log" + "net" + "os" + "os/signal" + "strconv" + + leveldb "github.com/jmhodges/levigo" +) + +type itemConsumer interface { + consumeInteger(int64) + consumeArray(int64) + consumeSimpleString(string) + consumeBulkString([]byte) + consumeError(error) +} + +type lexer struct { + reader *bufio.Reader + consumer itemConsumer +} + +func newLexer(reader *bufio.Reader, consumer itemConsumer) *lexer { + return &lexer{reader, consumer} +} + +func (l *lexer) nextLine() []byte { + line, err := l.reader.ReadBytes('\n') + if err != nil { + l.consumer.consumeError(err) + return nil + } + line = bytes.TrimRight(line, "\r\n") + log.Printf("%q", line) + return line +} + +func (l *lexer) dispatch(line []byte) bool { + if len(line) < 1 { + return false + } + switch line[0] { + case '-': + return true // ignore errors + case ':': + return l.integer(line) + case '+': + return l.simpleString(line) + case '$': + return l.bulkString(line) + case '*': + return l.array(line) + } + return true +} + +func (l *lexer) simpleString(line []byte) bool { + l.consumer.consumeSimpleString(string(line[1:])) + return true +} + +func (l *lexer) integer(line []byte) bool { + i, err := strconv.ParseInt(string(line[1:]), 10, 64) + if err != nil { + l.consumer.consumeError(err) + return false + } + l.consumer.consumeInteger(i) + return true +} + +func (l *lexer) bulkString(line []byte) bool { + var i int64 + var err error + i, err = strconv.ParseInt(string(line[1:]), 10, 64) + if err != nil { + l.consumer.consumeError(err) + return false + } + switch { + case i < 0: + l.consumer.consumeBulkString(nil) + case i == 0: + l.consumer.consumeBulkString([]byte{}) + default: + data := make([]byte, i, i) + var n int + if n, err = l.reader.Read(data); err != nil { + l.consumer.consumeError(err) + return false + } + if _, err = l.reader.ReadBytes('\n'); err != nil { + l.consumer.consumeError(err) + return false + } + l.consumer.consumeBulkString(data[0:n]) + } + return true +} + +func (l *lexer) array(line []byte) bool { + var i int64 + var err error + i, err = strconv.ParseInt(string(line[1:]), 10, 64) + if err != nil { + l.consumer.consumeError(err) + return false + } + l.consumer.consumeArray(i) + return true +} + +func (l *lexer) lex() { + for line := l.nextLine(); line != nil && l.dispatch(line); line = l.nextLine() { + } +} + +type commandExecutor interface { + get([]byte) + set(key, block []byte) + startTransaction() + commitTransaction() +} + +type commandConsumer struct { + exec commandExecutor + missing int64 + elements []interface{} +} + +func newCommandConsumer(exec commandExecutor) *commandConsumer { + return &commandConsumer{ + exec, + 0, + []interface{}{}} +} + +func (cc *commandConsumer) push(i interface{}) { + cc.elements = append(cc.elements, i) + cc.missing-- + if cc.missing <= 0 { + cc.missing = 0 + cc.execute() + cc.elements = []interface{}{} + } +} + +func asString(i interface{}) string { + switch i.(type) { + case string: + return i.(string) + case []byte: + return string(i.([]byte)) + } + return fmt.Sprintf("%s", i) +} + +func (cc *commandConsumer) execute() { + l := len(cc.elements) + log.Printf("command length: %d", l) + if l < 1 { + log.Printf("WARN: Too less argument for command") + return + } + cmd := asString(cc.elements[0]) + switch cmd { + case "HGET": + if l < 3 { + log.Println("WARN: Missing argment for HGET") + return + } + // ignore hash + if block, ok := cc.elements[2].([]byte); ok { + log.Printf("HGET %d", cc.elements[2]) + cc.exec.get(block) + } else { + log.Println("WARN HGET data is not a byte slice") + } + case "HSET": + if l < 4 { + log.Println("WARN: Missing argment for HSET") + return + } + // ignore hash + key, ok1 := cc.elements[2].([]byte) + block, ok2 := cc.elements[3].([]byte) + + if !ok1 || !ok2 { + log.Printf("WARM HSET key or data is not a byte slice") + return + } + cc.exec.set(key, block) + case "MULTI": + cc.exec.startTransaction() + case "EXEC": + cc.exec.commitTransaction() + default: + log.Printf("UNKOWN command: %s", cmd) + } +} + +func (cc *commandConsumer) consumeSimpleString(s string) { + log.Printf("simple string: %s", s) + cc.push(s) +} + +func shorten(data []byte) string { + if len(data) > 10 { + return fmt.Sprintf("%.10q...", data) + } + return fmt.Sprintf("%q", data) +} + +func (cc *commandConsumer) consumeBulkString(data []byte) { + s := shorten(data) + log.Printf("buld string: len = %d: %s", len(data), s) + cc.push(data) +} + +func (cc *commandConsumer) consumeInteger(i int64) { + log.Printf("integer: %d", i) + cc.push(i) +} + +func (cc *commandConsumer) consumeError(err error) { + log.Printf("error: %s", err) +} + +func (cc *commandConsumer) consumeArray(i int64) { + log.Printf("array: %d", i) + if cc.missing > 0 { + log.Println("WARN: Nested arrays are not supported!") + return + } + if i < 0 { + log.Println("Null arrays are not supported") + return + } + cc.missing = i +} + +type connection struct { + conn net.Conn + db *leveldb.DB + tx *leveldb.WriteBatch + intArray []int +} + +func newConnection(conn net.Conn, db *leveldb.DB) *connection { + return &connection{ + conn, + db, + nil, + []int{}} +} + +func (c *connection) run() { + defer c.conn.Close() + cc := newCommandConsumer(c) + r := bufio.NewReader(c.conn) + lexer := newLexer(r, cc) + lexer.lex() +} + +func (c *connection) get(key []byte) { + log.Printf("client requested block: %q", key) + var err error + var data []byte + ro := leveldb.NewReadOptions() + defer ro.Close() + if data, err = c.db.Get(ro, key); err != nil { + log.Printf("Something is wrong with db: %s", err) + if err = c.writeError(); err != nil { + log.Printf("Send message to client failed: %s", err) + } + } else { + if err = c.writeBlock(data); err != nil { + log.Printf("Send message to client failed: %s", err) + } + } +} + +func (c *connection) set(key, block []byte) { + log.Printf("client wants to store block: %q", key) + + var err error + var exists int + + if exists, err = c.keyExists(key); err != nil { + log.Printf("Something is wrong with db: %s", err) + if err = c.writeError(); err != nil { + log.Printf("Writing message to client failed: %s", err) + } + return + } + + if c.tx != nil { + c.tx.Put(key, block) + c.intArray = append(c.intArray, exists) + if err = c.writeQueued(); err != nil { + log.Printf("Writing message to client failed: %s", err) + } + return + } else { + wo := leveldb.NewWriteOptions() + defer wo.Close() + if err = c.db.Put(wo, key, block); err != nil { + log.Printf("Something is wrong with db: %s", err) + if err = c.writeError(); err != nil { + log.Printf("Writing message to client failed: %s", err) + } + return + } + if err = c.writeInteger(exists); err != nil { + log.Printf("Writing message to client failed: %s", err) + } + } +} + +func (c *connection) keyExists(key []byte) (exists int, err error) { + ro := leveldb.NewReadOptions() + defer ro.Close() + var data []byte + if data, err = c.db.Get(ro, key); err != nil { + return + } + if data != nil { + exists = 1 + } else { + exists = 0 + } + return +} + +func (c *connection) startTransaction() { + if c.tx != nil { + log.Println("WARN: Already running transaction.") + } else { + c.tx = leveldb.NewWriteBatch() + } + if err := c.writeOk(); err != nil { + log.Printf("Writing message to client failed: %s", err) + } +} + +func (c *connection) commitTransaction() { + var err error + if c.tx == nil { + if err = c.writeEmptyArray(); err != nil { + log.Printf("Writing message to client failed: %s", err) + } + return + } + tx := c.tx + c.tx = nil + defer tx.Close() + arr := c.intArray + c.intArray = []int{} + wo := leveldb.NewWriteOptions() + defer wo.Close() + if err = c.db.Write(wo, tx); err != nil { + log.Printf("Something went wrong in writing transaction: %s", err) + if err = c.writeError(); err != nil { + log.Printf("Writing message to client failed: %s", err) + } + return + } + if err = c.writeIntegerArray(arr); err != nil { + log.Printf("Writing message to client failed: %s", err) + } +} + +var ( + redisOk = []byte("+OK\r\n") + redisDbError = []byte("-FAIL\r\n") + redisNoSuchBlock = []byte("$-1\r\n") + redisCrnl = []byte("\r\n") + redisEmptyArray = []byte("*0\r\n") + redisQueued = []byte("+QUEUED\r\n") +) + +func (c *connection) writeError() (err error) { + _, err = c.conn.Write(redisDbError) + return +} + +func (c *connection) writeEmptyArray() (err error) { + _, err = c.conn.Write(redisDbError) + return +} + +func (c *connection) writeInteger(v int) (err error) { + _, err = c.conn.Write([]byte(fmt.Sprintf(":%d\r\n", v))) + return +} + +func (c *connection) writeIntegerArray(arr []int) (err error) { + if _, err = c.conn.Write([]byte(fmt.Sprintf("*%d\r\n", len(arr)))); err != nil { + return + } + for x := range arr { + if err = c.writeInteger(x); err != nil { + return + } + } + return +} + +func (c *connection) writeOk() (err error) { + _, err = c.conn.Write(redisOk) + return +} + +func (c *connection) writeQueued() (err error) { + _, err = c.conn.Write(redisQueued) + return +} + +func (c *connection) writeBlock(data []byte) (err error) { + con := c.conn + if data == nil { + _, err = con.Write(redisNoSuchBlock) + } else { + if _, err = con.Write([]byte(fmt.Sprintf("$%d\r\n", len(data)))); err != nil { + return + } + if _, err = con.Write(data); err != nil { + return + } + _, err = con.Write(redisCrnl) + } + return +} + +func main() { + + var port int + var host string + var cacheSize int + + flag.IntVar(&port, "port", 6379, "port to bind") + flag.StringVar(&host, "host", "", "host to bind") + flag.IntVar(&cacheSize, "cache", 32, "cache size in MB") + flag.Parse() + + args := flag.Args() + + if len(args) < 1 { + log.Fatal("Missing path to world") + } + + cache := leveldb.NewLRUCache(cacheSize * 1024 * 1024) + defer cache.Close() + + opts := leveldb.NewOptions() + opts.SetCache(cache) + opts.SetCreateIfMissing(true) + + var err error + var db *leveldb.DB + + if db, err = leveldb.Open(args[0], opts); err != nil { + log.Fatal(err) + } + defer db.Close() + + var listener net.Listener + + listener, err = net.Listen("tcp", fmt.Sprintf("%s:%d", host, port)) + if err != nil { + log.Fatal(err) + } + defer listener.Close() + + connChan := make(chan net.Conn) + defer close(connChan) + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, os.Kill) + + go func() { + log.Println("Server started") + for { + conn, err := listener.Accept() + if err != nil { + log.Fatal(err) + } + log.Printf("Client accepted: %s", conn) + connChan <- conn + } + }() + + for { + select { + case conn := <-connChan: + log.Printf("New connection %s", conn) + go newConnection(conn, db).run() + case <-sigChan: + log.Println("Shutting down") + return + } + } +}