diff --git a/main.go b/main.go index 7ee96e6..926e861 100644 --- a/main.go +++ b/main.go @@ -2,255 +2,21 @@ package main import ( "bufio" - "bytes" "flag" "fmt" "log" "net" "os" "os/signal" - "strconv" leveldb "github.com/jmhodges/levigo" ) -type itemConsumer interface { - consumeInteger(int64) - consumeArray(int64) - consumeSimpleString(string) - consumeBulkString([]byte) - consumeError(error) -} - -type lexer struct { - reader *bufio.Reader - consumer itemConsumer -} - -func newLexer(reader *bufio.Reader, consumer itemConsumer) *lexer { - return &lexer{reader, consumer} -} - -func (l *lexer) nextLine() []byte { - line, err := l.reader.ReadBytes('\n') - if err != nil { - l.consumer.consumeError(err) - return nil - } - line = bytes.TrimRight(line, "\r\n") - log.Printf("%q", line) - return line -} - -func (l *lexer) dispatch(line []byte) bool { - if len(line) < 1 { - return false - } - switch line[0] { - case '-': - return true // ignore errors - case ':': - return l.integer(line) - case '+': - return l.simpleString(line) - case '$': - return l.bulkString(line) - case '*': - return l.array(line) - } - return true -} - -func (l *lexer) simpleString(line []byte) bool { - l.consumer.consumeSimpleString(string(line[1:])) - return true -} - -func (l *lexer) integer(line []byte) bool { - i, err := strconv.ParseInt(string(line[1:]), 10, 64) - if err != nil { - l.consumer.consumeError(err) - return false - } - l.consumer.consumeInteger(i) - return true -} - -func (l *lexer) bulkString(line []byte) bool { - var i int64 - var err error - i, err = strconv.ParseInt(string(line[1:]), 10, 64) - if err != nil { - l.consumer.consumeError(err) - return false - } - switch { - case i < 0: - l.consumer.consumeBulkString(nil) - case i == 0: - l.consumer.consumeBulkString([]byte{}) - default: - data := make([]byte, i, i) - var n int - if n, err = l.reader.Read(data); err != nil { - l.consumer.consumeError(err) - return false - } - if _, err = l.reader.ReadBytes('\n'); err != nil { - l.consumer.consumeError(err) - return false - } - l.consumer.consumeBulkString(data[0:n]) - } - return true -} - -func (l *lexer) array(line []byte) bool { - var i int64 - var err error - i, err = strconv.ParseInt(string(line[1:]), 10, 64) - if err != nil { - l.consumer.consumeError(err) - return false - } - l.consumer.consumeArray(i) - return true -} - -func (l *lexer) lex() { - for line := l.nextLine(); line != nil && l.dispatch(line); line = l.nextLine() { - } -} - -type commandExecutor interface { - get([]byte) - set(key, block []byte) - startTransaction() - commitTransaction() -} - -type commandConsumer struct { - exec commandExecutor - missing int64 - elements []interface{} -} - -func newCommandConsumer(exec commandExecutor) *commandConsumer { - return &commandConsumer{ - exec, - 0, - []interface{}{}} -} - -func (cc *commandConsumer) push(i interface{}) { - cc.elements = append(cc.elements, i) - cc.missing-- - if cc.missing <= 0 { - cc.missing = 0 - cc.execute() - cc.elements = []interface{}{} - } -} - -func asString(i interface{}) string { - switch i.(type) { - case string: - return i.(string) - case []byte: - return string(i.([]byte)) - } - return fmt.Sprintf("%s", i) -} - -func (cc *commandConsumer) execute() { - l := len(cc.elements) - log.Printf("command length: %d", l) - if l < 1 { - log.Printf("WARN: Too less argument for command") - return - } - cmd := asString(cc.elements[0]) - switch cmd { - case "HGET": - if l < 3 { - log.Println("WARN: Missing argment for HGET") - return - } - // ignore hash - if block, ok := cc.elements[2].([]byte); ok { - log.Printf("HGET %d", cc.elements[2]) - cc.exec.get(block) - } else { - log.Println("WARN HGET data is not a byte slice") - } - case "HSET": - if l < 4 { - log.Println("WARN: Missing argment for HSET") - return - } - // ignore hash - key, ok1 := cc.elements[2].([]byte) - block, ok2 := cc.elements[3].([]byte) - - if !ok1 || !ok2 { - log.Printf("WARM HSET key or data is not a byte slice") - return - } - cc.exec.set(key, block) - case "MULTI": - cc.exec.startTransaction() - case "EXEC": - cc.exec.commitTransaction() - default: - log.Printf("UNKOWN command: %s", cmd) - } -} - -func (cc *commandConsumer) consumeSimpleString(s string) { - log.Printf("simple string: %s", s) - cc.push(s) -} - -func shorten(data []byte) string { - if len(data) > 10 { - return fmt.Sprintf("%.10q...", data) - } - return fmt.Sprintf("%q", data) -} - -func (cc *commandConsumer) consumeBulkString(data []byte) { - s := shorten(data) - log.Printf("buld string: len = %d: %s", len(data), s) - cc.push(data) -} - -func (cc *commandConsumer) consumeInteger(i int64) { - log.Printf("integer: %d", i) - cc.push(i) -} - -func (cc *commandConsumer) consumeError(err error) { - log.Printf("error: %s", err) -} - -func (cc *commandConsumer) consumeArray(i int64) { - log.Printf("array: %d", i) - if cc.missing > 0 { - log.Println("WARN: Nested arrays are not supported!") - return - } - if i < 0 { - log.Println("Null arrays are not supported") - return - } - cc.missing = i -} - type connection struct { - conn net.Conn - db *leveldb.DB - tx *leveldb.WriteBatch - intArray []int + conn net.Conn + db *leveldb.DB + tx *leveldb.WriteBatch + intArray []int } func newConnection(conn net.Conn, db *leveldb.DB) *connection { @@ -258,18 +24,18 @@ func newConnection(conn net.Conn, db *leveldb.DB) *connection { conn, db, nil, - []int{}} + []int{}} } func (c *connection) run() { defer c.conn.Close() - cc := newCommandConsumer(c) + rce := NewRedisCommandExecutor(c) r := bufio.NewReader(c.conn) - lexer := newLexer(r, cc) - lexer.lex() + parser := NewRedisParser(r, rce) + parser.Parse() } -func (c *connection) get(key []byte) { +func (c *connection) Hget(hash, key []byte) bool { log.Printf("client requested block: %q", key) var err error var data []byte @@ -279,102 +45,115 @@ func (c *connection) get(key []byte) { log.Printf("Something is wrong with db: %s", err) if err = c.writeError(); err != nil { log.Printf("Send message to client failed: %s", err) + return false } } else { if err = c.writeBlock(data); err != nil { log.Printf("Send message to client failed: %s", err) + return false } } + return true } -func (c *connection) set(key, block []byte) { +func (c *connection) Hset(hash, key, block []byte) bool { log.Printf("client wants to store block: %q", key) - var err error - var exists int + var err error + var exists int - if exists, err = c.keyExists(key); err != nil { + if exists, err = c.keyExists(key); err != nil { log.Printf("Something is wrong with db: %s", err) - if err = c.writeError(); err != nil { - log.Printf("Writing message to client failed: %s", err) - } - return - } + if err = c.writeError(); err != nil { + log.Printf("Writing message to client failed: %s", err) + return false + } + return true + } - if c.tx != nil { - c.tx.Put(key, block) - c.intArray = append(c.intArray, exists) - if err = c.writeQueued(); err != nil { - log.Printf("Writing message to client failed: %s", err) - } - return - } else { - wo := leveldb.NewWriteOptions() - defer wo.Close() - if err = c.db.Put(wo, key, block); err != nil { - log.Printf("Something is wrong with db: %s", err) - if err = c.writeError(); err != nil { - log.Printf("Writing message to client failed: %s", err) - } - return - } - if err = c.writeInteger(exists); err != nil { - log.Printf("Writing message to client failed: %s", err) - } - } + if c.tx != nil { + c.tx.Put(key, block) + c.intArray = append(c.intArray, exists) + if err = c.writeQueued(); err != nil { + log.Printf("Writing message to client failed: %s", err) + return false + } + return true + } + wo := leveldb.NewWriteOptions() + defer wo.Close() + if err = c.db.Put(wo, key, block); err != nil { + log.Printf("Something is wrong with db: %s", err) + if err = c.writeError(); err != nil { + log.Printf("Writing message to client failed: %s", err) + return false + } + return true + } + if err = c.writeInteger(exists); err != nil { + log.Printf("Writing message to client failed: %s", err) + return false + } + return true } func (c *connection) keyExists(key []byte) (exists int, err error) { ro := leveldb.NewReadOptions() defer ro.Close() - var data []byte + var data []byte if data, err = c.db.Get(ro, key); err != nil { - return - } - if data != nil { - exists = 1 - } else { - exists = 0 - } - return + return + } + if data != nil { + exists = 1 + } else { + exists = 0 + } + return } -func (c *connection) startTransaction() { - if c.tx != nil { - log.Println("WARN: Already running transaction.") - } else { - c.tx = leveldb.NewWriteBatch() - } - if err := c.writeOk(); err != nil { - log.Printf("Writing message to client failed: %s", err) - } +func (c *connection) Multi() bool { + if c.tx != nil { + log.Println("WARN: Already running transaction.") + } else { + c.tx = leveldb.NewWriteBatch() + } + if err := c.writeOk(); err != nil { + log.Printf("Writing message to client failed: %s", err) + return false + } + return true } -func (c *connection) commitTransaction() { - var err error - if c.tx == nil { - if err = c.writeEmptyArray(); err != nil { - log.Printf("Writing message to client failed: %s", err) - } - return - } - tx := c.tx - c.tx = nil - defer tx.Close() - arr := c.intArray - c.intArray = []int{} - wo := leveldb.NewWriteOptions() - defer wo.Close() - if err = c.db.Write(wo, tx); err != nil { - log.Printf("Something went wrong in writing transaction: %s", err) - if err = c.writeError(); err != nil { - log.Printf("Writing message to client failed: %s", err) - } - return - } - if err = c.writeIntegerArray(arr); err != nil { - log.Printf("Writing message to client failed: %s", err) - } +func (c *connection) Exec() bool { + var err error + if c.tx == nil { + if err = c.writeEmptyArray(); err != nil { + log.Printf("Writing message to client failed: %s", err) + return false + } + return true + } + tx := c.tx + c.tx = nil + defer tx.Close() + arr := c.intArray + c.intArray = []int{} + wo := leveldb.NewWriteOptions() + defer wo.Close() + if err = c.db.Write(wo, tx); err != nil { + log.Printf("Something went wrong in writing transaction: %s", err) + if err = c.writeError(); err != nil { + log.Printf("Writing message to client failed: %s", err) + return false + } + return true + } + if err = c.writeIntegerArray(arr); err != nil { + log.Printf("Writing message to client failed: %s", err) + return false + } + return true } var ( @@ -397,20 +176,20 @@ func (c *connection) writeEmptyArray() (err error) { } func (c *connection) writeInteger(v int) (err error) { - _, err = c.conn.Write([]byte(fmt.Sprintf(":%d\r\n", v))) - return + _, err = c.conn.Write([]byte(fmt.Sprintf(":%d\r\n", v))) + return } func (c *connection) writeIntegerArray(arr []int) (err error) { - if _, err = c.conn.Write([]byte(fmt.Sprintf("*%d\r\n", len(arr)))); err != nil { - return - } - for x := range arr { - if err = c.writeInteger(x); err != nil { - return - } - } - return + if _, err = c.conn.Write([]byte(fmt.Sprintf("*%d\r\n", len(arr)))); err != nil { + return + } + for x := range arr { + if err = c.writeInteger(x); err != nil { + return + } + } + return } func (c *connection) writeOk() (err error) { diff --git a/parser.go b/parser.go new file mode 100644 index 0000000..f3dfa99 --- /dev/null +++ b/parser.go @@ -0,0 +1,240 @@ +package main + +import ( + "bufio" + "bytes" + "fmt" + "log" + "strconv" +) + +type RedisConsumer interface { + ConsumeInteger(int64) bool + ConsumeArray(int64) bool + ConsumeSimpleString(string) bool + ConsumeBulkString([]byte) bool + ConsumeError(error) bool +} + +type RedisParser struct { + reader *bufio.Reader + consumer RedisConsumer +} + +func NewRedisParser(reader *bufio.Reader, consumer RedisConsumer) *RedisParser { + return &RedisParser{ + reader: reader, + consumer: consumer} +} + +func (rp *RedisParser) Parse() { + for line := rp.nextLine(); line != nil && rp.dispatch(line); line = rp.nextLine() { + } +} + +func (rp *RedisParser) nextLine() []byte { + line, err := rp.reader.ReadBytes('\n') + if err != nil { + rp.consumer.ConsumeError(err) + return nil + } + line = bytes.TrimRight(line, "\r\n") + log.Printf("%q", line) + return line +} + +func (rp *RedisParser) dispatch(line []byte) bool { + if len(line) < 1 { + return false + } + switch line[0] { + case '-': + return true // ignore errors + case ':': + return rp.integer(line) + case '+': + return rp.simpleString(line) + case '$': + return rp.bulkString(line) + case '*': + return rp.array(line) + } + return true +} + +func (rp *RedisParser) simpleString(line []byte) bool { + return rp.consumer.ConsumeSimpleString(string(line[1:])) +} + +func (rp *RedisParser) integer(line []byte) bool { + i, err := strconv.ParseInt(string(line[1:]), 10, 64) + if err != nil { + return rp.consumer.ConsumeError(err) + } + return rp.consumer.ConsumeInteger(i) +} + +func (rp *RedisParser) bulkString(line []byte) bool { + var i int64 + var err error + i, err = strconv.ParseInt(string(line[1:]), 10, 64) + if err != nil { + return rp.consumer.ConsumeError(err) + } + switch { + case i < 0: + return rp.consumer.ConsumeBulkString(nil) + case i == 0: + return rp.consumer.ConsumeBulkString([]byte{}) + default: + data := make([]byte, i, i) + var n int + if n, err = rp.reader.Read(data); err != nil { + return rp.consumer.ConsumeError(err) + } + if _, err = rp.reader.ReadBytes('\n'); err != nil { + return rp.consumer.ConsumeError(err) + } + return rp.consumer.ConsumeBulkString(data[0:n]) + } +} + +func (rp *RedisParser) array(line []byte) bool { + var i int64 + var err error + i, err = strconv.ParseInt(string(line[1:]), 10, 64) + if err != nil { + return rp.consumer.ConsumeError(err) + } + return rp.consumer.ConsumeArray(i) +} + +type RedisCommands interface { + Hget(hash, key []byte) bool + Hset(hash, key, block []byte) bool + Multi() bool + Exec() bool +} + +type RedisCommandExecutor struct { + commands RedisCommands + missing int64 + args []interface{} +} + +func NewRedisCommandExecutor(commands RedisCommands) *RedisCommandExecutor { + return &RedisCommandExecutor{ + commands: commands, + missing: 0, + args: []interface{}{}} +} + +func (rce *RedisCommandExecutor) push(i interface{}) bool { + rce.args = append(rce.args, i) + rce.missing-- + if rce.missing <= 0 { + rce.missing = 0 + res := rce.execute() + rce.args = []interface{}{} + return res + } + return true +} + +func asString(i interface{}) string { + switch i.(type) { + case string: + return i.(string) + case []byte: + return string(i.([]byte)) + } + return fmt.Sprintf("%s", i) +} + +func (rce *RedisCommandExecutor) execute() bool { + l := len(rce.args) + if l < 1 { + log.Printf("WARN: Too less argument for command") + return false + } + cmd := asString(rce.args[0]) + switch cmd { + case "HGET": + if l < 3 { + log.Println("WARN: Missing argments for HGET") + return false + } + hash, ok1 := rce.args[1].([]byte) + key, ok2 := rce.args[2].([]byte) + if !ok1 || !ok2 { + log.Println("WARN HGET data are not byte slices.") + return false + } + return rce.commands.Hget(hash, key) + + case "HSET": + if l < 4 { + log.Println("WARN: Missing argments for HSET.") + return false + } + hash, ok1 := rce.args[1].([]byte) + key, ok2 := rce.args[2].([]byte) + block, ok3 := rce.args[3].([]byte) + + if !ok1 || !ok2 || !ok3 { + log.Printf("WARM HSET data are not byte slices,") + return false + } + return rce.commands.Hset(hash, key, block) + + case "MULTI": + return rce.commands.Multi() + + case "EXEC": + return rce.commands.Exec() + } + log.Printf("WARN unkown command: %s", cmd) + return false +} + +func (rce *RedisCommandExecutor) ConsumeSimpleString(s string) bool { + log.Printf("simple string: %s", s) + return rce.push(s) +} + +func shorten(data []byte) string { + if len(data) > 10 { + return fmt.Sprintf("%.10q...", data) + } + return fmt.Sprintf("%q", data) +} + +func (rce *RedisCommandExecutor) ConsumeBulkString(data []byte) bool { + s := shorten(data) + log.Printf("buld string: len = %d: %s", len(data), s) + return rce.push(data) +} + +func (rce *RedisCommandExecutor) ConsumeInteger(i int64) bool { + log.Printf("integer: %d", i) + return rce.push(i) +} + +func (rce *RedisCommandExecutor) ConsumeError(err error) bool { + log.Printf("error: %s", err) + return true +} + +func (rce *RedisCommandExecutor) ConsumeArray(i int64) bool { + log.Printf("array: %d", i) + if rce.missing > 0 { + log.Println("WARN: Nested arrays are not supported!") + return false + } + if i < 0 { + log.Println("Null arrays are not supported") + return false + } + rce.missing = i + return true +}