diff --git a/imapclient/fetch.go b/imapclient/fetch.go index 74d95f13..a6662f1b 100644 --- a/imapclient/fetch.go +++ b/imapclient/fetch.go @@ -652,7 +652,7 @@ func (c *Client) handleFetch(seqNum uint32) error { var ( item FetchItemData - done chan struct{} + done chan error ) switch attName { case "FLAGS": @@ -738,7 +738,7 @@ func (c *Client) handleFetch(seqNum uint32) error { var fetchLit imap.LiteralReader if lit != nil { - done = make(chan struct{}) + done = make(chan error) fetchLit = &fetchLiteralReader{ LiteralReader: lit, ch: done, @@ -817,7 +817,9 @@ func (c *Client) handleFetch(seqNum uint32) error { } items <- item if done != nil { - <-done + if err := <-done; err != nil { + return err + } c.setReadTimeout(respReadTimeout) } return nil @@ -1313,12 +1315,18 @@ func readSectionPart(dec *imapwire.Decoder) (part []int, dot bool) { type fetchLiteralReader struct { *imapwire.LiteralReader - ch chan<- struct{} + ch chan<- error } func (lit *fetchLiteralReader) Read(b []byte) (int, error) { n, err := lit.LiteralReader.Read(b) - if err == io.EOF && lit.ch != nil { + if err == nil { + return n, nil + } + if lit.ch != nil { + if err != io.EOF { + lit.ch <- err + } close(lit.ch) lit.ch = nil } diff --git a/internal/imapwire/decoder.go b/internal/imapwire/decoder.go index cfd2995c..2efb975c 100644 --- a/internal/imapwire/decoder.go +++ b/internal/imapwire/decoder.go @@ -577,7 +577,7 @@ func (dec *Decoder) Literal(ptr *string) bool { } if dec.CheckBufferedLiteralFunc != nil { if err := dec.CheckBufferedLiteralFunc(lit.Size(), nonSync); err != nil { - lit.cancel() + lit.cancel(nil) return false } } @@ -607,7 +607,7 @@ func (dec *Decoder) LiteralReader() (lit *LiteralReader, nonSync, ok bool) { lit = &LiteralReader{ dec: dec, size: size, - r: io.LimitReader(dec.r, size), + r: newLimitReader(dec.r, int(size)), } return lit, nonSync, true } @@ -639,16 +639,24 @@ func (lit *LiteralReader) Size() int64 { func (lit *LiteralReader) Read(b []byte) (int, error) { n, err := lit.r.Read(b) - if err == io.EOF { - lit.cancel() + if err != nil { + if err == io.EOF { + lit.cancel(nil) + } else { + lit.cancel(err) + } } return n, err } -func (lit *LiteralReader) cancel() { +func (lit *LiteralReader) cancel(err error) { if lit.dec == nil { return } + + if err != nil { + lit.dec.err = err + } lit.dec.literal = false lit.dec = nil } diff --git a/internal/imapwire/limit_reader.go b/internal/imapwire/limit_reader.go new file mode 100644 index 00000000..c35359e2 --- /dev/null +++ b/internal/imapwire/limit_reader.go @@ -0,0 +1,45 @@ +package imapwire + +import "io" + +// A reader that returns io.EOF after a specified number of bytes. +// +// If EOF is received before the requested number of bytes, io.ErrUnexpectedEOF is returned. +type limitReader struct { + wrapped io.Reader + count int +} + +func (reader *limitReader) Read(destination []byte) (int, error) { + if reader.count == 0 { + return 0, io.EOF + } + if len(destination) == 0 { + return 0, nil + } + if len(destination) > reader.count { + destination = destination[:reader.count] + } + read, err := reader.wrapped.Read(destination) + reader.count -= read + if err == nil { + if reader.count == 0 { + return read, io.EOF + } + return read, nil + } + if err == io.EOF { + if reader.count == 0 { + return read, io.EOF + } + return read, io.ErrUnexpectedEOF + } + return read, err +} + +func newLimitReader(reader io.Reader, count int) *limitReader { + return &limitReader{ + wrapped: reader, + count: count, + } +}