diff --git a/protocol/packet/decoder.go b/protocol/packet/decoder.go index 53bb11b..827362f 100644 --- a/protocol/packet/decoder.go +++ b/protocol/packet/decoder.go @@ -15,22 +15,34 @@ import ( "bytes" "encoding/json" "fmt" + "io" + + "github.com/pkg/errors" ) func Unmarshal(data []byte, p *Packet) error { - if len(data) == 0 { - return fmt.Errorf("no data") + return NewDecoder(bytes.NewBuffer(data)).Decode(p) +} + +type Decoder struct { + r io.Reader + j *json.Decoder +} + +func NewDecoder(r io.Reader) *Decoder { + return &Decoder{ + r: r, + j: json.NewDecoder(r), } +} +func (d *Decoder) Decode(p *Packet) error { if p == nil { return fmt.Errorf("no packet") } - if idx := bytes.LastIndexByte(data, '\n'); idx != -1 { - data = data[0:idx] - } - if err := json.Unmarshal(data, p); err != nil { - return err + if err := d.j.Decode(p); err != nil { + return errors.Wrap(err, "failed to decode body") } if p.Id == uint64(0) || p.Type == "" || p.Body == nil { diff --git a/protocol/packet/decoder_test.go b/protocol/packet/decoder_test.go index 0668839..a16d798 100644 --- a/protocol/packet/decoder_test.go +++ b/protocol/packet/decoder_test.go @@ -12,6 +12,7 @@ package packet_test import ( + "strings" "testing" "github.com/stretchr/testify/assert" @@ -37,3 +38,34 @@ func TestUnmarshal(t *testing.T) { Body: []byte("{}"), }) } + +func TestDecoder(t *testing.T) { + input := ` +{"id": 123, "type": "foo","body":{}} +{"id": 456, "type": "bar","body":{"123": 123}} +{"id": 678, "type": "baz" +` + d := packet.NewDecoder(strings.NewReader(input)) + + p := packet.Packet{} + err := d.Decode(&p) + assert.NoError(t, err) + assert.Equal(t, p, packet.Packet{ + Id: uint64(123), + Type: "foo", + Body: []byte("{}"), + }) + + p = packet.Packet{} + err = d.Decode(&p) + assert.NoError(t, err) + assert.Equal(t, p, packet.Packet{ + Id: uint64(456), + Type: "bar", + Body: []byte(`{"123": 123}`), + }) + + p = packet.Packet{} + err = d.Decode(&p) + assert.Error(t, err) +} diff --git a/protocol/packet/encoder.go b/protocol/packet/encoder.go index 451d073..600852a 100644 --- a/protocol/packet/encoder.go +++ b/protocol/packet/encoder.go @@ -12,9 +12,12 @@ package packet import ( + "bytes" "encoding/json" - "fmt" + "io" "time" + + "github.com/pkg/errors" ) var getId = func() uint64 { @@ -22,30 +25,57 @@ var getId = func() uint64 { } func Marshal(p *Packet) ([]byte, error) { + b := &bytes.Buffer{} + enc := NewEncoder(b) + if err := enc.Encode(p); err != nil { + return nil, err + } + return b.Bytes(), nil +} + +type Encoder struct { + w io.Writer + j *json.Encoder +} + +func NewEncoder(w io.Writer) *Encoder { + return &Encoder{ + w: w, + j: json.NewEncoder(w), + } +} + +type auxPacket struct { + Packet + Body interface{} `json:"body"` +} + +func (e *Encoder) Encode(p *Packet) error { if p == nil { - return nil, fmt.Errorf("no packet") + return errors.New("no packet") } if p.Type == "" { - return nil, fmt.Errorf("packet type not set") + return errors.New("packet type not set") } - body, err := json.Marshal(p.auxBody) - if err != nil { - return nil, fmt.Errorf("failed to encode body: %v", err) - } id := p.Id if id == 0 { id = getId() } - data, err := json.Marshal(Packet{ - Id: id, - Type: p.Type, + + body := p.auxBody + // encodes packet and appends a newline character + err := e.j.Encode(auxPacket{ + Packet: Packet{ + Id: id, + Type: p.Type, + }, Body: body, }) if err != nil { - return nil, err + return errors.Wrap(err, "failed to encode body") } - return append(data, '\n'), nil + return nil }