Skip to content

Commit

Permalink
Change Identify API to accept only an io.Reader (#322)
Browse files Browse the repository at this point in the history
* [feat] don't require io.Seeker for identify

* tidy up

* Refactor and simplify with some bug fixes

* Clarify returned Reader in godoc comment

Co-authored-by: Matthew Holt <[email protected]>
  • Loading branch information
jhwz and mholt authored Mar 17, 2022
1 parent 9e7a9a7 commit 4fc750e
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 41 deletions.
141 changes: 104 additions & 37 deletions formats.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package archiver

import (
"bytes"
"context"
"errors"
"fmt"
Expand All @@ -25,20 +26,28 @@ func RegisterFormat(format Format) {
// value can be type-asserted to ascertain its capabilities.
//
// If no matching formats were found, special error ErrNoMatch is returned.
func Identify(filename string, stream io.ReadSeeker) (Format, error) {
//
// The returned io.Reader will always be non-nil and will read from the
// same point as the reader which was passed in; it should be used in place
// of the input stream after calling Identify() because it preserves and
// re-reads the bytes that were already read during the identification
// process.
func Identify(filename string, stream io.Reader) (Format, io.Reader, error) {
var compression Compression
var archival Archival

rewindableStream := newRewindReader(stream)

// try compression format first, since that's the outer "layer"
for name, format := range formats {
cf, isCompression := format.(Compression)
if !isCompression {
continue
}

matchResult, err := identifyOne(format, filename, stream, nil)
matchResult, err := identifyOne(format, filename, rewindableStream, nil)
if err != nil {
return nil, fmt.Errorf("matching %s: %w", name, err)
return nil, rewindableStream.reader(), fmt.Errorf("matching %s: %w", name, err)
}

// if matched, wrap input stream with decompression
Expand All @@ -56,9 +65,9 @@ func Identify(filename string, stream io.ReadSeeker) (Format, error) {
continue
}

matchResult, err := identifyOne(format, filename, stream, compression)
matchResult, err := identifyOne(format, filename, rewindableStream, compression)
if err != nil {
return nil, fmt.Errorf("matching %s: %w", name, err)
return nil, rewindableStream.reader(), fmt.Errorf("matching %s: %w", name, err)
}

if matchResult.Matched() {
Expand All @@ -67,57 +76,45 @@ func Identify(filename string, stream io.ReadSeeker) (Format, error) {
}
}

// the stream should be rewound by identifyOne
bufferedStream := rewindableStream.reader()
switch {
case compression != nil && archival == nil:
return compression, nil
return compression, bufferedStream, nil
case compression == nil && archival != nil:
return archival, nil
return archival, bufferedStream, nil
case compression != nil && archival != nil:
return CompressedArchive{compression, archival}, nil
return CompressedArchive{compression, archival}, bufferedStream, nil
default:
return nil, ErrNoMatch
return nil, bufferedStream, ErrNoMatch
}
}

func identifyOne(format Format, filename string, stream io.ReadSeeker, comp Compression) (MatchResult, error) {
if stream == nil {
// shimming an empty stream is easier than hoping every format's
// implementation of Match() expects and handles a nil stream
stream = strings.NewReader("")
}

// reset stream position to beginning, then restore current position when done
previousOffset, err := stream.Seek(0, io.SeekCurrent)
if err != nil {
return MatchResult{}, err
}
_, err = stream.Seek(0, io.SeekStart)
if err != nil {
return MatchResult{}, err
}
defer stream.Seek(previousOffset, io.SeekStart)
func identifyOne(format Format, filename string, stream *rewindReader, comp Compression) (mr MatchResult, err error) {
defer stream.rewind()

// if looking within a compressed format, wrap the stream in a
// reader that can decompress it so we can match the "inner" format
// (yes, we have to make a new reader every time we do a match,
// because we reset/seek the stream each time and that can mess up
// the compression reader's state if we don't discard it also)
if comp != nil {
decompressedStream, err := comp.OpenReader(stream)
if err != nil {
return MatchResult{}, err
decompressedStream, openErr := comp.OpenReader(stream)
if openErr != nil {
return MatchResult{}, openErr
}
defer decompressedStream.Close()
stream = struct {
io.Reader
io.Seeker
}{
Reader: decompressedStream,
Seeker: stream,
}
mr, err = format.Match(filename, decompressedStream)
} else {
mr, err = format.Match(filename, stream)
}

return format.Match(filename, stream)
// if the error is EOF, we can just ignore it.
// Just means we have a small input file.
if errors.Is(err, io.EOF) {
err = nil
}
return mr, err
}

// readAtMost reads at most n bytes from the stream. A nil, empty, or short
Expand Down Expand Up @@ -256,6 +253,76 @@ type MatchResult struct {
// Matched returns true if a match was made by either name or stream.
func (mr MatchResult) Matched() bool { return mr.ByName || mr.ByStream }

// rewindReader is a Reader that can be rewound (reset) to re-read what
// was already read and then continue to read more from the underlying
// stream. When no more rewinding is necessary, call reader() to get a
// new reader that first reads the buffered bytes, then continues to
// read from the stream. This is useful for "peeking" a stream an
// arbitrary number of bytes. Loosely based on the Connection type
// from https://github.com/mholt/caddy-l4.
type rewindReader struct {
io.Reader
buf *bytes.Buffer
bufReader io.Reader
}

func newRewindReader(r io.Reader) *rewindReader {
return &rewindReader{
Reader: r,
buf: new(bytes.Buffer),
}
}

func (rr *rewindReader) Read(p []byte) (n int, err error) {
// if there is a buffer we should read from, start
// with that; we only read from the underlying stream
// after the buffer has been "depleted"
if rr.bufReader != nil {
n, err = rr.bufReader.Read(p)
if err == io.EOF {
rr.bufReader = nil
err = nil
}
if n == len(p) {
return
}
}

// buffer has been "depleted" so read from
// underlying connection
nr, err := rr.Reader.Read(p[n:])

// anything that was read needs to be written to
// the buffer, even if there was an error
if nr > 0 {
if nw, errw := rr.buf.Write(p[n : n+nr]); errw != nil {
return nw, errw
}
}

// up to now, n was how many bytes were read from
// the buffer, and nr was how many bytes were read
// from the stream; add them to return total count
n += nr

return
}

// rewind resets the stream to the beginning by causing
// Read() to start reading from the beginning of the
// buffered bytes.
func (rr *rewindReader) rewind() {
rr.bufReader = bytes.NewReader(rr.buf.Bytes())
}

// reader returns a reader that reads first from the buffered
// bytes, then from the underlying stream. After calling this,
// no more rewinding is allowed since reads from the stream are
// not recorded, so rewinding properly is impossible.
func (rr *rewindReader) reader() io.Reader {
return io.MultiReader(bytes.NewReader(rr.buf.Bytes()), rr.Reader)
}

// ErrNoMatch is returned if there are no matching formats.
var ErrNoMatch = fmt.Errorf("no formats matched")

Expand Down
106 changes: 103 additions & 3 deletions formats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,111 @@ import (
"context"
"io"
"io/fs"
"math/rand"
"os"
"strings"
"testing"
"time"
)

func TestRewindReader(t *testing.T) {
data := "the header\nthe body\n"

r := newRewindReader(strings.NewReader(data))

buf := make([]byte, 10) // enough for 'the header'

// test rewinding reads
for i := 0; i < 10; i++ {
r.rewind()
n, err := r.Read(buf)
if err != nil {
t.Fatalf("Read failed: %s", err)
}
if string(buf[:n]) != "the header" {
t.Fatalf("iteration %d: expected 'the header' but got '%s' (n=%d)", i, string(buf[:n]), n)
}
}

// get the reader from header reader and make sure we can read all of the data out
r.rewind()
finalReader := r.reader()
buf = make([]byte, len(data))
n, err := io.ReadFull(finalReader, buf)
if err != nil {
t.Fatalf("ReadFull failed: %s (n=%d)", err, n)
}
if string(buf) != data {
t.Fatalf("expected '%s' but got '%s'", string(data), string(buf))
}
}

func TestCompression(t *testing.T) {
seed := time.Now().UnixNano()
t.Logf("seed: %d", seed)
r := rand.New(rand.NewSource(seed))

contents := make([]byte, 1024)
r.Read(contents)

compressed := new(bytes.Buffer)

testOK := func(t *testing.T, comp Compression, testFilename string) {
// compress into buffer
compressed.Reset()
wc, err := comp.OpenWriter(compressed)
checkErr(t, err, "opening writer")
_, err = wc.Write(contents)
checkErr(t, err, "writing contents")
checkErr(t, wc.Close(), "closing writer")

// make sure Identify correctly chooses this compression method
format, stream, err := Identify(testFilename, compressed)
checkErr(t, err, "identifying")
if format.Name() != comp.Name() {
t.Fatalf("expected format %s but got %s", comp.Name(), format.Name())
}

// read the contents back out and compare
decompReader, err := format.(Decompressor).OpenReader(stream)
checkErr(t, err, "opening with decompressor '%s'", format.Name())
data, err := io.ReadAll(decompReader)
checkErr(t, err, "reading decompressed data")
checkErr(t, decompReader.Close(), "closing decompressor")
if !bytes.Equal(data, contents) {
t.Fatalf("not equal to original")
}
}

var cannotIdentifyFromStream = map[string]bool{Brotli{}.Name(): true}

for _, f := range formats {
// only test compressors
comp, ok := f.(Compression)
if !ok {
continue
}

t.Run(f.Name()+"_with_extension", func(t *testing.T) {
testOK(t, comp, "file"+f.Name())
})
if !cannotIdentifyFromStream[f.Name()] {
t.Run(f.Name()+"_without_extension", func(t *testing.T) {
testOK(t, comp, "")
})
}
}
}

func checkErr(t *testing.T, err error, msgFmt string, args ...interface{}) {
t.Helper()
if err == nil {
return
}
args = append(args, err)
t.Fatalf(msgFmt+": %s", args...)
}

func TestIdentifyDoesNotMatchContentFromTrimmedKnownHeaderHaving0Suffix(t *testing.T) {
// Using the outcome of `n, err := io.ReadFull(stream, buf)` without minding n
// may lead to a mis-characterization for cases with known header ending with 0x0
Expand Down Expand Up @@ -41,7 +141,7 @@ func TestIdentifyDoesNotMatchContentFromTrimmedKnownHeaderHaving0Suffix(t *testi
}
headerTrimmed := tt.header[:headerLen-1]
stream := bytes.NewReader(headerTrimmed)
got, err := Identify("", stream)
got, _, err := Identify("", stream)
if got != nil {
t.Errorf("no Format expected for trimmed know %s header: found Format= %v", tt.name, got.Name())
return
Expand Down Expand Up @@ -84,7 +184,7 @@ func TestIdentifyCanAssessSmallOrNoContent(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Identify("", tt.args.stream)
got, _, err := Identify("", tt.args.stream)
if got != nil {
t.Errorf("no Format expected for non archive and not compressed stream: found Format= %v", got.Name())
return
Expand Down Expand Up @@ -274,7 +374,7 @@ func TestIdentifyFindFormatByStreamContent(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
stream := bytes.NewReader(compress(t, tt.compressorName, tt.content, tt.openCompressionWriter))
got, err := Identify("", stream)
got, _, err := Identify("", stream)
if err != nil {
t.Fatalf("should have found a corresponding Format: err :=%+v", err)
return
Expand Down
2 changes: 1 addition & 1 deletion fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func FileSystem(root string) (fs.FS, error) {
return nil, err
}
defer file.Close()
format, err := Identify(filepath.Base(root), file)
format, _, err := Identify(filepath.Base(root), file)
if err != nil && !errors.Is(err, ErrNoMatch) {
return nil, err
}
Expand Down

0 comments on commit 4fc750e

Please sign in to comment.