mirror of
1
Fork 0
forgejo/modules/zstd/zstd.go

164 lines
3.5 KiB
Go

// Copyright 2024 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
// Package zstd provides a high-level API for reading and writing zstd-compressed data.
// It supports both regular and seekable zstd streams.
// It's not a new wheel, but a wrapper around the zstd and zstd-seekable-format-go packages.
package zstd
import (
"errors"
"io"
seekable "github.com/SaveTheRbtz/zstd-seekable-format-go/pkg"
"github.com/klauspost/compress/zstd"
)
type Writer zstd.Encoder
var _ io.WriteCloser = (*Writer)(nil)
// NewWriter returns a new zstd writer.
func NewWriter(w io.Writer, opts ...WriterOption) (*Writer, error) {
zstdW, err := zstd.NewWriter(w, opts...)
if err != nil {
return nil, err
}
return (*Writer)(zstdW), nil
}
func (w *Writer) Write(p []byte) (int, error) {
return (*zstd.Encoder)(w).Write(p)
}
func (w *Writer) Close() error {
return (*zstd.Encoder)(w).Close()
}
type Reader zstd.Decoder
var _ io.ReadCloser = (*Reader)(nil)
// NewReader returns a new zstd reader.
func NewReader(r io.Reader, opts ...ReaderOption) (*Reader, error) {
zstdR, err := zstd.NewReader(r, opts...)
if err != nil {
return nil, err
}
return (*Reader)(zstdR), nil
}
func (r *Reader) Read(p []byte) (int, error) {
return (*zstd.Decoder)(r).Read(p)
}
func (r *Reader) Close() error {
(*zstd.Decoder)(r).Close() // no error returned
return nil
}
type SeekableWriter struct {
buf []byte
n int
w seekable.Writer
}
var _ io.WriteCloser = (*SeekableWriter)(nil)
// NewSeekableWriter returns a zstd writer to compress data to seekable format.
// blockSize is an important parameter, it should be decided according to the actual business requirements.
// If it's too small, the compression ratio could be very bad, even no compression at all.
// If it's too large, it could cost more traffic when reading the data partially from underlying storage.
func NewSeekableWriter(w io.Writer, blockSize int, opts ...WriterOption) (*SeekableWriter, error) {
zstdW, err := zstd.NewWriter(nil, opts...)
if err != nil {
return nil, err
}
seekableW, err := seekable.NewWriter(w, zstdW)
if err != nil {
return nil, err
}
return &SeekableWriter{
buf: make([]byte, blockSize),
w: seekableW,
}, nil
}
func (w *SeekableWriter) Write(p []byte) (int, error) {
written := 0
for len(p) > 0 {
n := copy(w.buf[w.n:], p)
w.n += n
written += n
p = p[n:]
if w.n == len(w.buf) {
if _, err := w.w.Write(w.buf); err != nil {
return written, err
}
w.n = 0
}
}
return written, nil
}
func (w *SeekableWriter) Close() error {
if w.n > 0 {
if _, err := w.w.Write(w.buf[:w.n]); err != nil {
return err
}
}
return w.w.Close()
}
type SeekableReader struct {
r seekable.Reader
c func() error
}
var _ io.ReadSeekCloser = (*SeekableReader)(nil)
// NewSeekableReader returns a zstd reader to decompress data from seekable format.
func NewSeekableReader(r io.ReadSeeker, opts ...ReaderOption) (*SeekableReader, error) {
zstdR, err := zstd.NewReader(nil, opts...)
if err != nil {
return nil, err
}
seekableR, err := seekable.NewReader(r, zstdR)
if err != nil {
return nil, err
}
ret := &SeekableReader{
r: seekableR,
}
if closer, ok := r.(io.Closer); ok {
ret.c = closer.Close
}
return ret, nil
}
func (r *SeekableReader) Read(p []byte) (int, error) {
return r.r.Read(p)
}
func (r *SeekableReader) Seek(offset int64, whence int) (int64, error) {
return r.r.Seek(offset, whence)
}
func (r *SeekableReader) Close() error {
return errors.Join(
func() error {
if r.c != nil {
return r.c()
}
return nil
}(),
r.r.Close(),
)
}