mirror of
1
Fork 0
forgejo/vendor/github.com/djherbis/buffer/multi.go

186 lines
3.9 KiB
Go

package buffer
import (
"bytes"
"encoding/gob"
"io"
"math"
)
// chain joins two BufferAts into one logical buffer. Buf is the head link:
// Write fills it before touching Next, and Read drains it before Next. Next
// holds everything after the head — either the final BufferAt of the list or
// another *chain, as assembled by NewMultiAt. Both links are serialized by
// MarshalBinary.
type chain struct {
	Buf, Next BufferAt
}
// nopBufferAt adapts a Buffer so it can be stored where a BufferAt is
// required. All Buffer methods are promoted from the embedded value; the two
// positional methods are placeholders that panic when called.
type nopBufferAt struct {
	Buffer
}

// ReadAt is not supported by the adapter and always panics.
func (*nopBufferAt) ReadAt(p []byte, off int64) (int, error) {
	panic("ReadAt not implemented")
}

// WriteAt is not supported by the adapter and always panics.
func (*nopBufferAt) WriteAt(p []byte, off int64) (int, error) {
	panic("WriteAt not implemented")
}
// toBufferAt wraps buf in a nopBufferAt so it satisfies BufferAt. Every Buffer
// method is forwarded to buf. The wrapper's ReadAt and WriteAt panic rather
// than doing nothing, so callers should never invoke positional access on the
// result.
func toBufferAt(buf Buffer) BufferAt {
	wrapped := nopBufferAt{Buffer: buf}
	return &wrapped
}
// NewMultiAt returns a BufferAt which is the logical concatenation of the passed BufferAts.
// The data in the buffers is shifted such that there is no non-empty buffer following
// a non-full buffer, this process is also run after every Read.
// If no buffers are passed, the returned BufferAt is nil; a single buffer is
// returned unchanged.
func NewMultiAt(buffers ...BufferAt) BufferAt {
	switch len(buffers) {
	case 0:
		return nil
	case 1:
		return buffers[0]
	}

	// Link from the tail towards the head. Each inner chain is created and
	// defragmented before the chain that wraps it, so data settles from the
	// back of the list forward.
	last := len(buffers) - 1
	var tail BufferAt = buffers[last]
	for i := last - 1; i >= 0; i-- {
		link := &chain{Buf: buffers[i], Next: tail}
		link.Defrag()
		tail = link
	}
	return tail
}
// NewMulti returns a Buffer which is the logical concatenation of the passed buffers.
// The data in the buffers is shifted such that there is no non-empty buffer following
// a non-full buffer, this process is also run after every Read.
// If no buffers are passed, the returned Buffer is nil.
//
// Each Buffer is wrapped with toBufferAt so the list can be handed to
// NewMultiAt; the wrappers do not support ReadAt/WriteAt.
func NewMulti(buffers ...Buffer) Buffer {
	adapted := make([]BufferAt, 0, len(buffers))
	for _, b := range buffers {
		adapted = append(adapted, toBufferAt(b))
	}
	return NewMultiAt(adapted...)
}
// Reset empties the whole chain, resetting the tail link (Next) before the
// head link (Buf).
func (buf *chain) Reset() {
	tail, head := buf.Next, buf.Buf
	tail.Reset()
	head.Reset()
}
// Cap reports the combined capacity of Buf and Next. The sum saturates at
// math.MaxInt64 instead of wrapping around when it does not fit in an int64.
func (buf *chain) Cap() (n int64) {
	rest := buf.Next.Cap()
	head := buf.Buf.Cap()
	if head <= math.MaxInt64-rest {
		return head + rest
	}
	return math.MaxInt64
}
// Len reports how many bytes are stored across Buf and Next. The sum
// saturates at math.MaxInt64 instead of wrapping around on overflow.
func (buf *chain) Len() (n int64) {
	rest := buf.Next.Len()
	head := buf.Buf.Len()
	if head <= math.MaxInt64-rest {
		return head + rest
	}
	return math.MaxInt64
}
// Defrag moves data from the front of Next into Buf until Buf is full or Next
// is empty. This restores the invariant that no non-empty buffer follows a
// non-full one.
//
// Defrag gives up and leaves the chain as-is if a copy fails with an error
// other than io.EOF. It also gives up if a copy moves no bytes at all.
// Without that second check, a link whose Len/Cap/Gap disagree with what its
// Read/Write deliver (e.g. Gap <= 0 while not Full, or Read returning io.EOF
// while Len > 0) would make this loop spin forever.
func (buf *chain) Defrag() {
	for !Full(buf.Buf) && !Empty(buf.Next) {
		r := io.LimitReader(buf.Next, Gap(buf.Buf))
		copied, err := io.Copy(buf.Buf, r)
		// io.EOF is tolerated because a custom ReadFrom on Buf may surface it.
		if err != nil && err != io.EOF {
			return
		}
		if copied == 0 {
			// No progress: stop instead of looping indefinitely.
			return
		}
	}
}
// Read fills p from the front of the chain: first from Buf, then, if p still
// has room and Buf did not fail, from Next. After a read that did not fail in
// Next, the chain is defragmented so remaining data shifts forward into Buf.
//
// io.EOF from Buf only means the head link ran dry. It is not propagated while
// unread data remains in the chain, whether that data was read from Next in
// this call or moved into Buf by Defrag. io.EOF is reported only once the
// whole chain is drained.
func (buf *chain) Read(p []byte) (n int, err error) {
	n, err = buf.Buf.Read(p)
	if len(p[n:]) > 0 && (err == nil || err == io.EOF) {
		var m int
		// Assign rather than redeclare err, so a successful read from Next
		// replaces an io.EOF that came from Buf.
		m, err = buf.Next.Read(p[n:])
		n += m
		if err != nil {
			return n, err
		}
	}
	buf.Defrag()
	if err == io.EOF && buf.Len() > 0 {
		// Buf was exhausted, but the chain is not.
		err = nil
	}
	return n, err
}
// ReadAt reads len(p) bytes starting at logical offset off within the chain.
// Offsets beyond the data held in Buf are translated into Next's coordinate
// space. A read that starts in Buf and runs past its end continues at offset
// 0 of Next.
func (buf *chain) ReadAt(p []byte, off int64) (n int, err error) {
	headLen := buf.Buf.Len()
	if off > headLen {
		return buf.Next.ReadAt(p, off-headLen)
	}

	n, err = buf.Buf.ReadAt(p, off)
	if rest := p[n:]; len(rest) > 0 && (err == nil || err == io.EOF) {
		// The head was exhausted before p was filled; keep reading from Next.
		var m int
		m, err = buf.Next.ReadAt(rest, 0)
		n += m
	}
	return n, err
}
// Write appends p to the chain, filling Buf first and sending whatever does
// not fit to Next. io.ErrShortWrite from Buf only means the head is full, so
// it is not treated as a failure. Any other error from Buf stops the write.
func (buf *chain) Write(p []byte) (n int, err error) {
	n, err = buf.Buf.Write(p)
	if err == io.ErrShortWrite {
		err = nil
	}

	rest := p[n:]
	if len(rest) == 0 || err != nil {
		return n, err
	}

	var m int
	m, err = buf.Next.Write(rest)
	return n + m, err
}
// WriteAt writes p at logical offset off. Placement is decided by Buf's
// capacity, not its length:
//   - a write starting at or past Buf's capacity goes to Next, with the
//     offset shifted by that capacity;
//   - a write that ends within Buf's capacity goes entirely to Buf;
//   - a write that straddles the boundary writes its head into Buf and its
//     remainder at offset 0 of Next.
func (buf *chain) WriteAt(p []byte, off int64) (n int, err error) {
	headCap := buf.Buf.Cap()
	end := off + int64(len(p))

	if headCap <= off {
		return buf.Next.WriteAt(p, off-headCap)
	}
	if end <= headCap {
		return buf.Buf.WriteAt(p, off)
	}

	n, err = buf.Buf.WriteAt(p, off)
	if rest := p[n:]; len(rest) > 0 && (err == nil || err == io.ErrShortWrite) {
		var m int
		m, err = buf.Next.WriteAt(rest, 0)
		n += m
	}
	return n, err
}
// init registers the package's concrete link types with gob. This lets them
// be encoded and decoded through the BufferAt interface fields written by
// MarshalBinary.
func init() {
	for _, v := range []interface{}{&chain{}, &nopBufferAt{}} {
		gob.Register(v)
	}
}
// MarshalBinary gob-encodes Buf followed by Next into a single byte slice,
// using one encoder for both.
func (buf *chain) MarshalBinary() ([]byte, error) {
	var out bytes.Buffer
	enc := gob.NewEncoder(&out)
	// Pointers to the interface-typed fields make gob send each value as an
	// interface, including its registered concrete type name, rather than
	// as a bare concrete value.
	for _, field := range []interface{}{&buf.Buf, &buf.Next} {
		if err := enc.Encode(field); err != nil {
			return nil, err
		}
	}
	return out.Bytes(), nil
}
// UnmarshalBinary restores Buf and then Next from data produced by
// MarshalBinary, decoding both values with a single gob decoder.
func (buf *chain) UnmarshalBinary(data []byte) error {
	dec := gob.NewDecoder(bytes.NewReader(data))
	for _, field := range []interface{}{&buf.Buf, &buf.Next} {
		if err := dec.Decode(field); err != nil {
			return err
		}
	}
	return nil
}