diff --git a/decode.go b/decode.go
index d6e9dae..f050d56 100644
--- a/decode.go
+++ b/decode.go
@@ -21,7 +21,8 @@ func Unmarshal(data []byte, v interface{}) error {
 
 // A Decoder reads and decodes fixed width data from an input stream.
 type Decoder struct {
-	data                *bufio.Reader
+	scanner             *bufio.Scanner
+	lineTerminator      []byte
 	done                bool
 	useCodepointIndices bool
 
@@ -31,9 +32,12 @@ type Decoder struct {
 
 // NewDecoder returns a new decoder that reads from r.
 func NewDecoder(r io.Reader) *Decoder {
-	return &Decoder{
-		data: bufio.NewReader(r),
+	dec := &Decoder{
+		scanner:        bufio.NewScanner(r),
+		lineTerminator: []byte("\n"),
 	}
+	dec.scanner.Split(dec.scan)
+	return dec
 }
 
 // An InvalidUnmarshalError describes an invalid argument passed to Unmarshal.
@@ -178,19 +182,46 @@ func findFirstMultiByteChar(data string) int {
 	return len(data)
 }
 
-func (d *Decoder) readLine(v reflect.Value) (err error, ok bool) {
-	line, err := d.data.ReadString('\n')
-	if err != nil && err != io.EOF {
-		return err, false
+// SetLineTerminator sets the byte sequence that marks the end of each line.
+//
+// The default value is "\n". An empty lineTerminator is ignored and the
+// current terminator is kept. The slice is copied, so the caller may reuse it.
+func (d *Decoder) SetLineTerminator(lineTerminator []byte) {
+	if len(lineTerminator) > 0 {
+		d.lineTerminator = append([]byte(nil), lineTerminator...)
 	}
-	if err == io.EOF {
-		d.done = true
+}
 
-		if len(line) <= 0 || line[0] == '\n' {
-			// skip last empty lines
-			return nil, false
-		}
+// scan is a bufio.SplitFunc that splits the input on d.lineTerminator. The
+// terminator is not included in the returned token. A final line that is not
+// followed by a terminator is returned as-is once the input is exhausted.
+func (d *Decoder) scan(data []byte, atEOF bool) (advance int, token []byte, err error) {
+	if atEOF && len(data) == 0 {
+		return 0, nil, nil
+	}
+	if i := bytes.Index(data, d.lineTerminator); i >= 0 {
+		// We have a full terminator-delimited line; drop the terminator.
+		return i + len(d.lineTerminator), data[0:i], nil
+	}
+	// If we're at EOF, we have a final, non-terminated line. Return it.
+	if atEOF {
+		return len(data), data, nil
 	}
+	// Request more data.
+	return 0, nil, nil
+}
+
+func (d *Decoder) readLine(v reflect.Value) (err error, ok bool) {
+	ok = d.scanner.Scan()
+	if !ok {
+		d.done = true
+		// Scan also returns false on a read or tokenizing error; Err is nil
+		// only when the input simply ended.
+		return d.scanner.Err(), false
+	}
+
+	line := string(d.scanner.Bytes())
+
 	rawValue, err := newRawValue(line, d.useCodepointIndices)
 	if err != nil {
 		return
diff --git a/decode_test.go b/decode_test.go
index 5ce24ce..1cdbed0 100644
--- a/decode_test.go
+++ b/decode_test.go
@@ -332,3 +332,79 @@ func TestNewRawValue(t *testing.T) {
 		})
 	}
 }
+
+func TestLineSeparator(t *testing.T) {
+	// allTypes contains a field with all current supported types.
+	type allTypes struct {
+		String          string          `fixed:"1,5"`
+		Int             int             `fixed:"6,10"`
+		Float           float64         `fixed:"11,15"`
+		TextUnmarshaler EncodableString `fixed:"16,20"`
+	}
+	for _, tt := range []struct {
+		name           string
+		rawValue       []byte
+		target         interface{}
+		expected       interface{}
+		shouldErr      bool
+		lineTerminator []byte
+	}{
+		{
+			name:     "default line endings (empty terminator ignored)",
+			rawValue: []byte("foo  123  1.2  bar" + "\n" + "bar  321  2.1  foo"),
+			target:   &[]allTypes{},
+			expected: &[]allTypes{
+				{"foo", 123, 1.2, EncodableString{"bar", nil}},
+				{"bar", 321, 2.1, EncodableString{"foo", nil}},
+			},
+			shouldErr:      false,
+			lineTerminator: []byte{},
+		},
+		{
+			name:     "LF line endings",
+			rawValue: []byte("f\ro  123  1.2  bar" + "\n" + "bar  321  2.1  foo"),
+			target:   &[]allTypes{},
+			expected: &[]allTypes{
+				{"f\ro", 123, 1.2, EncodableString{"bar", nil}},
+				{"bar", 321, 2.1, EncodableString{"foo", nil}},
+			},
+			shouldErr:      false,
+			lineTerminator: []byte("\n"),
+		},
+		{
+			name:     "CRLF line endings",
+			rawValue: []byte("f\no  123  1.2  bar" + "\r\n" + "bar  321  2.1  foo"),
+			target:   &[]allTypes{},
+			expected: &[]allTypes{
+				{"f\no", 123, 1.2, EncodableString{"bar", nil}},
+				{"bar", 321, 2.1, EncodableString{"foo", nil}},
+			},
+			shouldErr:      false,
+			lineTerminator: []byte("\r\n"),
+		},
+		{
+			name:     "CR line endings",
+			rawValue: []byte("f\no  123  1.2  bar" + "\r" + "bar  321  2.1  foo"),
+			target:   &[]allTypes{},
+			expected: &[]allTypes{
+				{"f\no", 123, 1.2, EncodableString{"bar", nil}},
+				{"bar", 321, 2.1, EncodableString{"foo", nil}},
+			},
+			shouldErr:      false,
+			lineTerminator: []byte("\r"),
+		},
+	} {
+		t.Run(tt.name, func(t *testing.T) {
+			dec := NewDecoder(bytes.NewReader(tt.rawValue))
+			dec.SetLineTerminator(tt.lineTerminator)
+			err := dec.Decode(tt.target)
+			if tt.shouldErr != (err != nil) {
+				t.Errorf("Decode() err want %v, have %v (%v)", tt.shouldErr, err != nil, err)
+			}
+			if !tt.shouldErr && !reflect.DeepEqual(tt.target, tt.expected) {
+				t.Errorf("Decode() want %+v, have %+v", tt.expected, tt.target)
+			}
+
+		})
+	}
+}