135 lines
2.8 KiB
Go
135 lines
2.8 KiB
Go
|
package fastcopy
|
||
|
|
||
|
import (
|
||
|
"io"
|
||
|
"sync"
|
||
|
_ "unsafe" // link to io.errInvalidWrite.
|
||
|
)
|
||
|
|
||
|
var (
|
||
|
// global pool instance.
|
||
|
pool = CopyPool{size: 4096}
|
||
|
|
||
|
//go:linkname errInvalidWrite io.errInvalidWrite
|
||
|
errInvalidWrite error
|
||
|
)
|
||
|
|
||
|
// CopyPool is a memory pool of reusable byte buffers, used as
// scratch space when streaming data from a reader into a writer.
type CopyPool struct {
	// size is the length, in bytes, given to newly allocated buffers.
	size int

	// pool keeps buffers released by finished copies
	// so that later copies can pick them up again.
	pool sync.Pool
}
|
||
|
|
||
|
// See CopyPool.Buffer().
|
||
|
func Buffer(sz int) int {
|
||
|
return pool.Buffer(sz)
|
||
|
}
|
||
|
|
||
|
// See CopyPool.CopyN().
|
||
|
func CopyN(dst io.Writer, src io.Reader, n int64) (int64, error) {
|
||
|
return pool.CopyN(dst, src, n)
|
||
|
}
|
||
|
|
||
|
// See CopyPool.Copy().
|
||
|
func Copy(dst io.Writer, src io.Reader) (int64, error) {
|
||
|
return pool.Copy(dst, src)
|
||
|
}
|
||
|
|
||
|
// Buffer sets the pool buffer size to allocate. Returns current size.
|
||
|
// Note this is NOT atomically safe, please call BEFORE other calls to CopyPool.
|
||
|
func (cp *CopyPool) Buffer(sz int) int {
|
||
|
if sz > 0 {
|
||
|
// update size
|
||
|
cp.size = sz
|
||
|
} else if cp.size < 1 {
|
||
|
// default size
|
||
|
return 4096
|
||
|
}
|
||
|
return cp.size
|
||
|
}
|
||
|
|
||
|
// CopyN performs the same logic as io.CopyN(), with the difference
|
||
|
// being that the byte buffer is acquired from a memory pool.
|
||
|
func (cp *CopyPool) CopyN(dst io.Writer, src io.Reader, n int64) (int64, error) {
|
||
|
written, err := cp.Copy(dst, io.LimitReader(src, n))
|
||
|
if written == n {
|
||
|
return n, nil
|
||
|
}
|
||
|
if written < n && err == nil {
|
||
|
// src stopped early; must have been EOF.
|
||
|
err = io.EOF
|
||
|
}
|
||
|
return written, err
|
||
|
}
|
||
|
|
||
|
// Copy performs the same logic as io.Copy(), with the difference
|
||
|
// being that the byte buffer is acquired from a memory pool.
|
||
|
func (cp *CopyPool) Copy(dst io.Writer, src io.Reader) (int64, error) {
|
||
|
// Prefer using io.WriterTo to do the copy (avoids alloc + copy)
|
||
|
if wt, ok := src.(io.WriterTo); ok {
|
||
|
return wt.WriteTo(dst)
|
||
|
}
|
||
|
|
||
|
// Prefer using io.ReaderFrom to do the copy.
|
||
|
if rt, ok := dst.(io.ReaderFrom); ok {
|
||
|
return rt.ReadFrom(src)
|
||
|
}
|
||
|
|
||
|
var buf []byte
|
||
|
|
||
|
if b, ok := cp.pool.Get().([]byte); ok {
|
||
|
// Acquired buf from pool
|
||
|
buf = b
|
||
|
} else {
|
||
|
// Allocate new buffer of size
|
||
|
buf = make([]byte, cp.Buffer(0))
|
||
|
}
|
||
|
|
||
|
// Defer release to pool
|
||
|
defer cp.pool.Put(buf)
|
||
|
|
||
|
var n int64
|
||
|
for {
|
||
|
// Perform next read into buf
|
||
|
nr, err := src.Read(buf)
|
||
|
if nr > 0 {
|
||
|
// We error check AFTER checking
|
||
|
// no. read bytes so incomplete
|
||
|
// read still gets written up to nr.
|
||
|
|
||
|
// Perform next write from buf
|
||
|
nw, ew := dst.Write(buf[0:nr])
|
||
|
|
||
|
// Check for valid write
|
||
|
if nw < 0 || nr < nw {
|
||
|
if ew == nil {
|
||
|
ew = errInvalidWrite
|
||
|
}
|
||
|
return n, ew
|
||
|
}
|
||
|
|
||
|
// Incr total count
|
||
|
n += int64(nw)
|
||
|
|
||
|
// Check write error
|
||
|
if ew != nil {
|
||
|
return n, ew
|
||
|
}
|
||
|
|
||
|
// Check unequal read/writes
|
||
|
if nr != nw {
|
||
|
return n, io.ErrShortWrite
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Return on err
|
||
|
if err != nil {
|
||
|
if err == io.EOF {
|
||
|
err = nil // expected
|
||
|
}
|
||
|
return n, err
|
||
|
}
|
||
|
}
|
||
|
}
|