[feature] add rate limit middleware (#741)
* feat: add rate limit middleware
* chore: update vendor dir
* chore: update readme with new dependency
* chore: add rate limit infos to swagger.md file
* refactor: add ipv6 mask limiter option (add IPv6 CIDR /64 mask)
* refactor: increase rate limit to 1000 (addresses https://github.com/superseriousbusiness/gotosocial/pull/741#discussion_r945584800)

Co-authored-by: tobi <31960611+tsmethurst@users.noreply.github.com>
This commit is contained in:
parent
daec9ab10e
commit
bee8458a2d
@ -248,6 +248,7 @@ The following libraries and frameworks are used by GoToSocial, with gratitude
  - [tdewolff/minify](https://github.com/tdewolff/minify); HTML minification for Markdown-submitted posts. [MIT License](https://spdx.org/licenses/MIT.html).
+ - [ulule/limiter](https://github.com/ulule/limiter); HTTP rate limit middleware. [MIT License](https://spdx.org/licenses/MIT.html).
  - [uptrace/bun](https://github.com/uptrace/bun); database ORM. [BSD-2-Clause License](https://spdx.org/licenses/BSD-2-Clause.html).
  - [wagslane/go-password-validator](https://github.com/wagslane/go-password-validator); password strength validation. [MIT License](https://spdx.org/licenses/MIT.html).

  ### Image Attribution

swagger.md
@ -1,5 +1,16 @@
  # API Documentation

+ ## Rate limit
+
+ To prevent abuse of the API, an IP-based HTTP rate limit is in place: a maximum of 1000 requests is allowed in a 5-minute time window. Every response includes the current status of the rate limit in the following headers:
+
+ - `x-ratelimit-limit`: maximum number of requests allowed per time period (fixed)
+ - `x-ratelimit-remaining`: number of remaining requests that can still be performed
+ - `x-ratelimit-reset`: unix timestamp when the rate limit will reset
+
+ In case the rate limit is exceeded, an HTTP 429 error is returned to the caller.

  GoToSocial uses [go-swagger](https://github.com/go-swagger/go-swagger) to generate a V2 [OpenAPI specification](https://swagger.io/specification/v2/) document from code annotations.

  The resulting API documentation is rendered below, for quick reference.
go.mod
@ -41,6 +41,7 @@ require (
  github.com/superseriousbusiness/activity v1.1.0-gts
  github.com/superseriousbusiness/exif-terminator v0.4.0
  github.com/superseriousbusiness/oauth2/v4 v4.3.2-SSB
  github.com/tdewolff/minify/v2 v2.12.0
+ github.com/ulule/limiter/v3 v3.10.0
  github.com/uptrace/bun v1.1.7
  github.com/uptrace/bun/dialect/pgdialect v1.1.7
@ -98,7 +99,7 @@ require (
  github.com/jinzhu/inflection v1.0.0 // indirect
  github.com/json-iterator/go v1.1.12 // indirect
  github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
- github.com/klauspost/compress v1.13.6 // indirect
+ github.com/klauspost/compress v1.15.0 // indirect
  github.com/klauspost/cpuid v1.3.1 // indirect
  github.com/leodido/go-urn v1.2.1 // indirect
  github.com/magiconair/properties v1.8.6 // indirect
@ -110,6 +111,7 @@ require (
  github.com/modern-go/reflect2 v1.0.2 // indirect
  github.com/pelletier/go-toml v1.9.5 // indirect
  github.com/pelletier/go-toml/v2 v2.0.0 // indirect
+ github.com/pkg/errors v0.9.1 // indirect
  github.com/pmezard/go-difflib v1.0.0 // indirect
  github.com/quasoft/memstore v0.0.0-20191010062613-2bce066d2b0b // indirect
  github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect
go.sum
@ -95,8 +95,8 @@ github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030I
  github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs=
  github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
  github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
- github.com/andybalholm/brotli v1.0.0 h1:7UCwP93aiSfvWpapti8g88vVVGp2qqtGyePsSuDafo4=
  github.com/andybalholm/brotli v1.0.0/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y=
+ github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
  github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
  github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
  github.com/buckket/go-blurhash v1.1.0 h1:X5M6r0LIvwdvKiUtiNcRL2YlmOfMzYobI3VCKCZc9Do=
@ -372,8 +372,8 @@ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:C
  github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
  github.com/klauspost/compress v1.10.4/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
  github.com/klauspost/compress v1.10.10/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
- github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc=
+ github.com/klauspost/compress v1.15.0 h1:xqfchp4whNFxn5A4XFyyYtitiWI8Hy5EW59jEwcyL6U=
- github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
+ github.com/klauspost/compress v1.15.0/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
  github.com/klauspost/cpuid v1.2.3/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek=
  github.com/klauspost/cpuid v1.3.1 h1:5JNjFYYQrZeKRJ0734q51WCEEn2huer72Dc7K+R/b6s=
  github.com/klauspost/cpuid v1.3.1/go.mod h1:bYW4mA6ZgKPob1/Dlai2LviZJO7KGI3uoWLd42rAQw4=
@ -549,6 +549,8 @@ github.com/ugorji/go v1.2.7/go.mod h1:nF9osbDWLy6bDVv/Rtoh6QgnvNDpmCalQV5urGCCS6
  github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY=
  github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0=
  github.com/ugorji/go/codec v1.2.7/go.mod h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY=
+ github.com/ulule/limiter/v3 v3.10.0 h1:C9mx3tgxYnt4pUYKWktZf7aEOVPbRYxR+onNFjQTEp0=
+ github.com/ulule/limiter/v3 v3.10.0/go.mod h1:NqPA/r8QfP7O11iC+95X6gcWJPtRWjKrtOUw07BTvoo=
  github.com/uptrace/bun v1.1.7 h1:biOoh5dov69hQPBlaRsXSHoEOIEnCxFzQvUmbscSNJI=
  github.com/uptrace/bun v1.1.7/go.mod h1:Z2Pd3cRvNKbrYuL6Gp1XGjA9QEYz+rDz5KkEi9MZLnQ=
  github.com/uptrace/bun/dialect/pgdialect v1.1.7 h1:94GPc8RRC9AVoQ+4KCqRX2zScevsVfOttk13wm60/P8=
@ -557,8 +559,8 @@ github.com/uptrace/bun/dialect/sqlitedialect v1.1.7 h1:xxc1n1nUdn6zqY6ji1ZkiaHQy
  github.com/uptrace/bun/dialect/sqlitedialect v1.1.7/go.mod h1:GjqiPWAa9JCLlv51mB1rjk8QRgwv6HlQ+IAtyrobfAY=
  github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
  github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
- github.com/valyala/fasthttp v1.14.0 h1:67bfuW9azCMwW/Jlq/C+VeihNpAuJMWkYPBig1gdi3A=
  github.com/valyala/fasthttp v1.14.0/go.mod h1:ol1PCaL0dX20wC0htZ7sYCsvCYmrouYra0zHzaclZhE=
+ github.com/valyala/fasthttp v1.34.0 h1:d3AAQJ2DRcxJYHm7OXNXtXt2as1vMDfxeIcFvhmGGm4=
  github.com/valyala/tcplisten v0.0.0-20161114210144-ceec8f93295a/go.mod h1:v3UYOV9WzVtRmSR+PDvWpU/qWl4Wa5LApYYX4ZtKbio=
  github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU=
  github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
@ -0,0 +1,70 @@
+ /*
+    GoToSocial
+    Copyright (C) 2021-2022 GoToSocial Authors admin@gotosocial.org
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU Affero General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU Affero General Public License for more details.
+
+    You should have received a copy of the GNU Affero General Public License
+    along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+ package security
+
+ import (
+ 	"net"
+ 	"net/http"
+ 	"time"
+
+ 	"github.com/gin-gonic/gin"
+ 	limiter "github.com/ulule/limiter/v3"
+ 	mgin "github.com/ulule/limiter/v3/drivers/middleware/gin"
+ 	memory "github.com/ulule/limiter/v3/drivers/store/memory"
+ )
+
+ // RateLimitOptions holds the time period and request limit for the rate limiter.
+ type RateLimitOptions struct {
+ 	Period time.Duration
+ 	Limit  int64
+ }
+
+ // LimitReachedHandler aborts the request with a 429 JSON error.
+ func (m *Module) LimitReachedHandler(c *gin.Context) {
+ 	code := http.StatusTooManyRequests
+ 	c.AbortWithStatusJSON(code, gin.H{"error": "rate limit reached"})
+ }
+
+ // RateLimit returns a gin middleware that automatically rate limits callers
+ // (by IP address) and enriches the response with the following headers:
+ //
+ //   - `x-ratelimit-limit`: maximum number of requests allowed per time period (fixed)
+ //   - `x-ratelimit-remaining`: number of remaining requests that can still be performed
+ //   - `x-ratelimit-reset`: unix timestamp when the rate limit will reset
+ //
+ // If `x-ratelimit-limit` is exceeded, an HTTP 429 error is returned to the caller.
+ func (m *Module) RateLimit(rateOptions RateLimitOptions) func(c *gin.Context) {
+ 	rate := limiter.Rate{
+ 		Period: rateOptions.Period,
+ 		Limit:  rateOptions.Limit,
+ 	}
+
+ 	store := memory.NewStore()
+
+ 	limiterInstance := limiter.New(
+ 		store,
+ 		rate,
+ 		// apply /64 mask to IPv6 addresses
+ 		limiter.WithIPv6Mask(net.CIDRMask(64, 128)),
+ 	)
+
+ 	middleware := mgin.NewMiddleware(
+ 		limiterInstance,
+ 		// use custom rate limit reached error
+ 		mgin.WithLimitReachedHandler(m.LimitReachedHandler),
+ 	)
+
+ 	return middleware
+ }
@ -20,6 +20,7 @@ package security
  import (
  	"net/http"
+ 	"time"

  	"github.com/superseriousbusiness/gotosocial/internal/api"
  	"github.com/superseriousbusiness/gotosocial/internal/db"
@ -45,6 +46,11 @@ func New(db db.DB, server oauth.Server) api.ClientModule {
  // Route attaches security middleware to the given router
  func (m *Module) Route(s router.Router) error {
+ 	s.AttachMiddleware(m.RateLimit(RateLimitOptions{
+ 		// accept a maximum of 1000 requests in a 5-minute window
+ 		Period: 5 * time.Minute,
+ 		Limit:  1000,
+ 	}))
  	s.AttachMiddleware(m.SignatureCheck)
  	s.AttachMiddleware(m.FlocBlock)
  	s.AttachMiddleware(m.ExtraHeaders)

vendor/github.com/klauspost/compress/s2/README.md
@ -20,6 +20,7 @@ This is important, so you don't have to worry about spending CPU cycles on alrea
  * Concurrent stream compression
  * Faster decompression, even for Snappy compatible content
  * Ability to quickly skip forward in compressed stream
+ * Random seeking with indexes
  * Compatible with reading Snappy compressed content
  * Smaller block size overhead on incompressible blocks
  * Block concatenation
@ -29,8 +30,8 @@ This is important, so you don't have to worry about spending CPU cycles on alrea
  ## Drawbacks over Snappy

- * Not optimized for 32 bit systems.
- * Streams use slightly more memory due to larger blocks and concurrency (configurable).
+ * Not optimized for 32 bit systems
+ * Streams use slightly more memory due to larger blocks and concurrency (configurable)

  # Usage
@ -141,7 +142,7 @@ Binaries can be downloaded on the [Releases Page](https://github.com/klauspost/c
  Installing then requires Go to be installed. To install them, use:

- `go install github.com/klauspost/compress/s2/cmd/s2c && go install github.com/klauspost/compress/s2/cmd/s2d`
+ `go install github.com/klauspost/compress/s2/cmd/s2c@latest && go install github.com/klauspost/compress/s2/cmd/s2d@latest`

  To build binaries to the current folder use:
@ -176,6 +177,8 @@ Options:
  	Compress faster, but with a minor compression loss
  -help
  	Display help
+ -index
+ 	Add seek index (default true)
  -o string
  	Write output to another file. Single input file only
  -pad string
@ -217,11 +220,15 @@ Options:
  	Display help
  -o string
  	Write output to another file. Single input file only
+ -offset string
+ 	Start at offset. Examples: 92, 64K, 256K, 1M, 4M. Requires Index
  -q	Don't write any output to terminal, except errors
  -rm
  	Delete source file(s) after successful decompression
  -safe
  	Do not overwrite output files
+ -tail string
+ 	Return last of compressed file. Examples: 92, 64K, 256K, 1M, 4M. Requires Index
  -verify
  	Verify files, but do not write output
  ```
@ -634,11 +641,11 @@ Comparison of [`webdevdata.org-2015-01-07-subset`](https://files.klauspost.com/c
  53927 files, total input size: 4,014,735,833 bytes. amd64, single goroutine used:

  | Encoder               | Size       | MB/s       | Reduction  |
- |-----------------------|------------|--------|------------
+ |-----------------------|------------|------------|------------|
  | snappy.Encode         | 1128706759 | 725.59     | 71.89%     |
- | s2.EncodeSnappy       | 1093823291 | 899.16     | 72.75%     |
+ | s2.EncodeSnappy       | 1093823291 | **899.16** | 72.75%     |
  | s2.EncodeSnappyBetter | 1001158548 | 578.49     | 75.06%     |
- | s2.EncodeSnappyBest   | 944507998  | 66.00      | 76.47%     |
+ | s2.EncodeSnappyBest   | 944507998  | 66.00      | **76.47%** |

  ## Streams
@ -649,11 +656,11 @@ Comparison of different streams, AMD Ryzen 3950x, 16 cores. Size and throughput:
  | File                        | snappy.NewWriter         | S2 Snappy                 | S2 Snappy, Better        | S2 Snappy, Best         |
  |-----------------------------|--------------------------|---------------------------|--------------------------|-------------------------|
- | nyc-taxi-data-10M.csv       | 1316042016 - 517.54MB/s  | 1307003093 - 8406.29MB/s  | 1174534014 - 4984.35MB/s | 1115904679 - 177.81MB/s |
+ | nyc-taxi-data-10M.csv       | 1316042016 - 539.47MB/s  | 1307003093 - 10132.73MB/s | 1174534014 - 5002.44MB/s | 1115904679 - 177.97MB/s |
- | enwik10                     | 5088294643 - 433.45MB/s  | 5175840939 - 8454.52MB/s  | 4560784526 - 4403.10MB/s | 4340299103 - 159.71MB/s |
+ | enwik10 (xml)               | 5088294643 - 451.13MB/s  | 5175840939 - 9440.69MB/s  | 4560784526 - 4487.21MB/s | 4340299103 - 158.92MB/s |
- | 10gb.tar                    | 6056946612 - 703.25MB/s  | 6208571995 - 9035.75MB/s  | 5741646126 - 2402.08MB/s | 5548973895 - 171.17MB/s |
+ | 10gb.tar (mixed)            | 6056946612 - 729.73MB/s  | 6208571995 - 9978.05MB/s  | 5741646126 - 4919.98MB/s | 5548973895 - 180.44MB/s |
- | github-june-2days-2019.json | 1525176492 - 908.11MB/s  | 1476519054 - 12625.93MB/s | 1400547532 - 6163.61MB/s | 1321887137 - 200.71MB/s |
+ | github-june-2days-2019.json | 1525176492 - 933.00MB/s  | 1476519054 - 13150.12MB/s | 1400547532 - 5803.40MB/s | 1321887137 - 204.29MB/s |
- | consensus.db.10gb           | 5412897703 - 1054.38MB/s | 5354073487 - 12634.82MB/s | 5335069899 - 2472.23MB/s | 5201000954 - 166.32MB/s |
+ | consensus.db.10gb (db)      | 5412897703 - 1102.14MB/s | 5354073487 - 13562.91MB/s | 5335069899 - 5294.73MB/s | 5201000954 - 175.72MB/s |

  # Decompression
@ -680,6 +687,219 @@ The 10 byte 'stream identifier' of the second stream can optionally be stripped,
  Blocks can be concatenated using the `ConcatBlocks` function.

  Snappy blocks/streams can safely be concatenated with S2 blocks and streams.
+ Streams with indexes (see below) will currently not work on concatenated streams.
+
+ # Stream Seek Index
+
+ S2 and Snappy streams can have indexes. These indexes will allow random seeking within the compressed data.
+
+ The index can either be appended to the stream as a skippable block or returned for separate storage.
+
+ When the index is appended to a stream it will be skipped by regular decoders,
+ so the output remains compatible with other decoders.
+
+ ## Creating an Index
+
+ To automatically add an index to a stream, add the `WriterAddIndex()` option to your writer.
+ Then the index will be added to the stream when `Close()` is called.
+
+ ```
+ // Add Index to stream...
+ enc := s2.NewWriter(w, s2.WriterAddIndex())
+ io.Copy(enc, r)
+ enc.Close()
+ ```
+
+ If you want to store the index separately, you can use `CloseIndex()` instead of the regular `Close()`.
+ This will return the index. Note that `CloseIndex()` should only be called once, and you shouldn't call `Close()`.
+
+ ```
+ // Get index for separate storage...
+ enc := s2.NewWriter(w)
+ io.Copy(enc, r)
+ index, err := enc.CloseIndex()
+ ```
+
+ The `index` can then be used without needing to read from the stream.
+ This means the index can be used without needing to seek to the end of the stream,
+ or for manually forwarding streams. See below.
+
+ Finally, an existing S2/Snappy stream can be indexed using the `s2.IndexStream(r io.Reader)` function.
+
+ ## Using Indexes
+
+ To use indexes there is a `ReadSeeker(random bool, index []byte) (*ReadSeeker, error)` function available.
+
+ Calling `ReadSeeker` will return an [io.ReadSeeker](https://pkg.go.dev/io#ReadSeeker) compatible version of the reader.
+
+ If 'random' is specified the returned io.Seeker can be used for random seeking, otherwise only forward seeking is supported.
+ Enabling random seeking requires the original input to support the [io.Seeker](https://pkg.go.dev/io#Seeker) interface.
+
+ ```
+ dec := s2.NewReader(r)
+ rs, err := dec.ReadSeeker(false, nil)
+ rs.Seek(wantOffset, io.SeekStart)
+ ```
+
+ This gets a seeker that can seek forward. Since no index is provided, the index is read from the stream.
+ This requires that an index was added and that `r` supports the [io.Seeker](https://pkg.go.dev/io#Seeker) interface.
+
+ A custom index can be specified, which will be used if supplied.
+ When using a custom index, it will not be read from the input stream.
+
+ ```
+ dec := s2.NewReader(r)
+ rs, err := dec.ReadSeeker(false, index)
+ rs.Seek(wantOffset, io.SeekStart)
+ ```
+
+ This will read the index from `index`. Since we specify non-random (forward only) seeking, `r` does not have to be an io.Seeker.
+
+ ```
+ dec := s2.NewReader(r)
+ rs, err := dec.ReadSeeker(true, index)
+ rs.Seek(wantOffset, io.SeekStart)
+ ```
+
+ Finally, since we specify that we want to do random seeking, `r` must be an io.Seeker.
+
+ The returned [ReadSeeker](https://pkg.go.dev/github.com/klauspost/compress/s2#ReadSeeker) contains a shallow reference to the existing Reader,
+ meaning changes performed to one are reflected in the other.
+
+ To check if a stream contains an index at the end, the `(*Index).LoadStream(rs io.ReadSeeker) error` method can be used.
+
+ ## Manually Forwarding Streams
+
+ Indexes can also be read outside the decoder using the [Index](https://pkg.go.dev/github.com/klauspost/compress/s2#Index) type.
+ This can be used for parsing indexes, either separate or in streams.
+
+ In some cases it may not be possible to serve a seekable stream.
+ This can for instance be an HTTP stream, where the Range request
+ is sent at the start of the stream.
+
+ With a little bit of extra code it is still possible to use indexes
+ to forward to a specific offset with a single forward skip.
+
+ It is possible to load the index manually like this:
+
+ ```
+ var index s2.Index
+ _, err = index.Load(idxBytes)
+ ```
+
+ This can be used to figure out how much to offset the compressed stream:
+
+ ```
+ compressedOffset, uncompressedOffset, err := index.Find(wantOffset)
+ ```
+
+ The `compressedOffset` is the number of bytes that should be skipped
+ from the beginning of the compressed file.
+
+ The `uncompressedOffset` will then be the offset of the uncompressed bytes returned
+ when decoding from that position. This will always be <= wantOffset.
+
+ When creating a decoder it must be specified that it should *not* expect a stream identifier
+ at the beginning of the stream. Assuming the io.Reader `r` has been forwarded to `compressedOffset`,
+ we create the decoder like this:
+
+ ```
+ dec := s2.NewReader(r, s2.ReaderIgnoreStreamIdentifier())
+ ```
+
+ We are not completely done. We still need to forward the stream the uncompressed bytes we didn't want.
+ This is done using the regular "Skip" function:
+
+ ```
+ err = dec.Skip(wantOffset - uncompressedOffset)
+ ```
+
+ This will ensure that we are at exactly the offset we want, and reading from `dec` will start at the requested offset.
+
+ ## Index Format
+
+ Each block is structured as a snappy skippable block, with the chunk ID 0x99.
+
+ The block can be read from the front, but contains information so it can be read from the back as well.
+
+ Numbers are stored as fixed size little endian values or [zigzag encoded](https://developers.google.com/protocol-buffers/docs/encoding#signed_integers) [base 128 varints](https://developers.google.com/protocol-buffers/docs/encoding),
+ with un-encoded value length of 64 bits, unless other limits are specified.
+
+ | Content                              | Format                                                                                                                       |
+ |--------------------------------------|------------------------------------------------------------------------------------------------------------------------------|
+ | ID, `[1]byte`                        | Always 0x99.                                                                                                                 |
+ | Data Length, `[3]byte`               | 3 byte little-endian length of the chunk in bytes, following this.                                                           |
+ | Header `[6]byte`                     | Header, must be `[115, 50, 105, 100, 120, 0]` or in text: "s2idx\x00".                                                       |
+ | UncompressedSize, Varint             | Total Uncompressed size.                                                                                                     |
+ | CompressedSize, Varint               | Total Compressed size if known. Should be -1 if unknown.                                                                     |
+ | EstBlockSize, Varint                 | Block Size, used for guessing uncompressed offsets. Must be >= 0.                                                            |
+ | Entries, Varint                      | Number of Entries in index, must be < 65536 and >= 0.                                                                        |
+ | HasUncompressedOffsets `byte`        | 0 if no uncompressed offsets are present, 1 if present. Other values are invalid.                                            |
+ | UncompressedOffsets, [Entries]VarInt | Uncompressed offsets. See below how to decode.                                                                               |
+ | CompressedOffsets, [Entries]VarInt   | Compressed offsets. See below how to decode.                                                                                 |
+ | Block Size, `[4]byte`                | Little Endian total encoded size (including header and trailer). Can be used for searching backwards to start of block.      |
+ | Trailer `[6]byte`                    | Trailer, must be `[0, 120, 100, 105, 50, 115]` or in text: "\x00xdi2s". Can be used for identifying block from end of stream. |
+
+ For regular streams the uncompressed offsets are fully predictable,
+ so `HasUncompressedOffsets` allows specifying that compressed blocks all have
+ exactly `EstBlockSize` bytes of uncompressed content.
+
+ Entries *must* be in order, starting with the lowest offset,
+ and there *must* be no uncompressed offset duplicates.
+ Entries *may* point to the start of a skippable block,
+ but it is then not allowed to also have an entry for the next block, since
+ that would give an uncompressed offset duplicate.
+
+ There is no requirement for all blocks to be represented in the index.
+ In fact there is a maximum of 65536 block entries in an index.
+
+ The writer can use any method to reduce the number of entries.
+ An implicit block start at 0,0 can be assumed.
+
+ ### Decoding entries
+
+ ```
+ // Read Uncompressed entries.
+ // Each assumes EstBlockSize delta from previous.
+ for each entry {
+     uOff = 0
+     if HasUncompressedOffsets == 1 {
+         uOff = ReadVarInt // Read value from stream
+     }
+
+     // Except for the first entry, use previous values.
+     if entryNum == 0 {
+         entry[entryNum].UncompressedOffset = uOff
+         continue
+     }
+
+     // Uncompressed uses previous offset and adds EstBlockSize
+     entry[entryNum].UncompressedOffset = entry[entryNum-1].UncompressedOffset + EstBlockSize + uOff
+ }
+
+ // Guess that the first block will be 50% of uncompressed size.
+ // Integer truncating division must be used.
+ CompressGuess := EstBlockSize / 2
+
+ // Read Compressed entries.
+ // Each assumes CompressGuess delta from previous.
+ // CompressGuess is adjusted for each value.
+ for each entry {
+     cOff = ReadVarInt // Read value from stream
+
+     // Except for the first entry, use previous values.
+     if entryNum == 0 {
+         entry[entryNum].CompressedOffset = cOff
+         continue
+     }
+
+     // Compressed uses previous and our estimate.
+     entry[entryNum].CompressedOffset = entry[entryNum-1].CompressedOffset + CompressGuess + cOff
+
+     // Adjust compressed offset for next loop, integer truncating division must be used.
+     CompressGuess += cOff / 2
+ }
+ ```
|
||||||
|
|
||||||
# Format Extensions
@@ -8,7 +8,9 @@ package s2
 import (
 	"encoding/binary"
 	"errors"
+	"fmt"
 	"io"
+	"io/ioutil"
 )

 var (
@@ -22,6 +24,16 @@ var (
 	ErrUnsupported = errors.New("s2: unsupported input")
 )
+
+// ErrCantSeek is returned if the stream cannot be seeked.
+type ErrCantSeek struct {
+	Reason string
+}
+
+// Error returns the error as string.
+func (e ErrCantSeek) Error() string {
+	return fmt.Sprintf("s2: Can't seek because %s", e.Reason)
+}

 // DecodedLen returns the length of the decoded block.
 func DecodedLen(src []byte) (int, error) {
 	v, _, err := decodedLen(src)
@@ -88,6 +100,7 @@ func NewReader(r io.Reader, opts ...ReaderOption) *Reader {
 	} else {
 		nr.buf = make([]byte, MaxEncodedLen(defaultBlockSize)+checksumSize)
 	}
+	nr.readHeader = nr.ignoreStreamID
 	nr.paramsOK = true
 	return &nr
 }
@@ -131,12 +144,41 @@ func ReaderAllocBlock(blockSize int) ReaderOption {
 	}
 }

+// ReaderIgnoreStreamIdentifier will make the reader skip the expected
+// stream identifier at the beginning of the stream.
+// This can be used when serving a stream that has been forwarded to a specific point.
+func ReaderIgnoreStreamIdentifier() ReaderOption {
+	return func(r *Reader) error {
+		r.ignoreStreamID = true
+		return nil
+	}
+}
+
+// ReaderSkippableCB will register a callback for chunks with the specified ID.
+// ID must be a Reserved skippable chunks ID, 0x80-0xfd (inclusive).
+// For each chunk with the ID, the callback is called with the content.
+// Any returned non-nil error will abort decompression.
+// Only one callback per ID is supported, latest sent will be used.
+func ReaderSkippableCB(id uint8, fn func(r io.Reader) error) ReaderOption {
+	return func(r *Reader) error {
+		if id < 0x80 || id > 0xfd {
+			return fmt.Errorf("ReaderSkippableCB: Invalid id provided, must be 0x80-0xfd (inclusive)")
+		}
+		r.skippableCB[id-0x80] = fn
+		return nil
+	}
+}
+
 // Reader is an io.Reader that can read Snappy-compressed bytes.
 type Reader struct {
 	r       io.Reader
 	err     error
 	decoded []byte
 	buf     []byte
+	skippableCB [0x80]func(r io.Reader) error
+	blockStart  int64 // Uncompressed offset at start of current.
+	index       *Index

 	// decoded[i:j] contains decoded bytes that have not yet been passed on.
 	i, j int
 	// maximum block size allowed.
@@ -148,6 +190,7 @@ type Reader struct {
 	readHeader  bool
 	paramsOK    bool
 	snappyFrame bool
+	ignoreStreamID bool
 }

 // ensureBufferSize will ensure that the buffer can take at least n bytes.
@@ -172,11 +215,12 @@ func (r *Reader) Reset(reader io.Reader) {
 	if !r.paramsOK {
 		return
 	}
+	r.index = nil
 	r.r = reader
 	r.err = nil
 	r.i = 0
 	r.j = 0
-	r.readHeader = false
+	r.readHeader = r.ignoreStreamID
 }

 func (r *Reader) readFull(p []byte, allowEOF bool) (ok bool) {
@@ -189,11 +233,24 @@ func (r *Reader) readFull(p []byte, allowEOF bool) (ok bool) {
 	return true
 }

-// skipN will skip n bytes.
+// skippable will skip n bytes.
 // If the supplied reader supports seeking that is used.
 // tmp is used as a temporary buffer for reading.
 // The supplied slice does not need to be the size of the read.
-func (r *Reader) skipN(tmp []byte, n int, allowEOF bool) (ok bool) {
+func (r *Reader) skippable(tmp []byte, n int, allowEOF bool, id uint8) (ok bool) {
+	if id < 0x80 {
+		r.err = fmt.Errorf("internal error: skippable id < 0x80")
+		return false
+	}
+	if fn := r.skippableCB[id-0x80]; fn != nil {
+		rd := io.LimitReader(r.r, int64(n))
+		r.err = fn(rd)
+		if r.err != nil {
+			return false
+		}
+		_, r.err = io.CopyBuffer(ioutil.Discard, rd, tmp)
+		return r.err == nil
+	}
 	if rs, ok := r.r.(io.ReadSeeker); ok {
 		_, err := rs.Seek(int64(n), io.SeekCurrent)
 		if err == nil {
@@ -247,6 +304,7 @@ func (r *Reader) Read(p []byte) (int, error) {
 		// https://github.com/google/snappy/blob/master/framing_format.txt
 		switch chunkType {
 		case chunkTypeCompressedData:
+			r.blockStart += int64(r.j)
 			// Section 4.2. Compressed data (chunk type 0x00).
 			if chunkLen < checksumSize {
 				r.err = ErrCorrupt
@@ -294,6 +352,7 @@ func (r *Reader) Read(p []byte) (int, error) {
 			continue

 		case chunkTypeUncompressedData:
+			r.blockStart += int64(r.j)
 			// Section 4.3. Uncompressed data (chunk type 0x01).
 			if chunkLen < checksumSize {
 				r.err = ErrCorrupt
@@ -357,17 +416,20 @@ func (r *Reader) Read(p []byte) (int, error) {

 		if chunkType <= 0x7f {
 			// Section 4.5. Reserved unskippable chunks (chunk types 0x02-0x7f).
+			// fmt.Printf("ERR chunktype: 0x%x\n", chunkType)
 			r.err = ErrUnsupported
 			return 0, r.err
 		}
 		// Section 4.4 Padding (chunk type 0xfe).
 		// Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
-		if chunkLen > maxBlockSize {
+		if chunkLen > maxChunkSize {
+			// fmt.Printf("ERR chunkLen: 0x%x\n", chunkLen)
 			r.err = ErrUnsupported
 			return 0, r.err
 		}

-		if !r.skipN(r.buf, chunkLen, false) {
+		// fmt.Printf("skippable: ID: 0x%x, len: 0x%x\n", chunkType, chunkLen)
+		if !r.skippable(r.buf, chunkLen, false, chunkType) {
 			return 0, r.err
 		}
 	}
@@ -396,7 +458,7 @@ func (r *Reader) Skip(n int64) error {
 			return nil
 		}
 		n -= int64(r.j - r.i)
-		r.i, r.j = 0, 0
+		r.i = r.j
 	}

 	// Buffer empty; read blocks until we have content.
|
||||||
// https://github.com/google/snappy/blob/master/framing_format.txt
|
// https://github.com/google/snappy/blob/master/framing_format.txt
|
||||||
switch chunkType {
|
switch chunkType {
|
||||||
case chunkTypeCompressedData:
|
case chunkTypeCompressedData:
|
||||||
|
r.blockStart += int64(r.j)
|
||||||
// Section 4.2. Compressed data (chunk type 0x00).
|
// Section 4.2. Compressed data (chunk type 0x00).
|
||||||
if chunkLen < checksumSize {
|
if chunkLen < checksumSize {
|
||||||
r.err = ErrCorrupt
|
r.err = ErrCorrupt
|
||||||
|
@@ -468,6 +531,7 @@ func (r *Reader) Skip(n int64) error {
 			r.i, r.j = 0, dLen
 			continue
 		case chunkTypeUncompressedData:
+			r.blockStart += int64(r.j)
 			// Section 4.3. Uncompressed data (chunk type 0x01).
 			if chunkLen < checksumSize {
 				r.err = ErrCorrupt
@@ -528,19 +592,138 @@ func (r *Reader) Skip(n int64) error {
 			r.err = ErrUnsupported
 			return r.err
 		}
-		if chunkLen > maxBlockSize {
+		if chunkLen > maxChunkSize {
 			r.err = ErrUnsupported
 			return r.err
 		}
 		// Section 4.4 Padding (chunk type 0xfe).
 		// Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
-		if !r.skipN(r.buf, chunkLen, false) {
+		if !r.skippable(r.buf, chunkLen, false, chunkType) {
 			return r.err
 		}
 	}
 	return nil
 }
+
+// ReadSeeker provides random or forward seeking in compressed content.
+// See Reader.ReadSeeker
+type ReadSeeker struct {
+	*Reader
+}
+
+// ReadSeeker will return an io.ReadSeeker compatible version of the reader.
+// If 'random' is specified the returned io.Seeker can be used for
+// random seeking, otherwise only forward seeking is supported.
+// Enabling random seeking requires the original input to support
+// the io.Seeker interface.
+// A custom index can be specified which will be used if supplied.
+// When using a custom index, it will not be read from the input stream.
+// The returned ReadSeeker contains a shallow reference to the existing Reader,
+// meaning changes performed to one are reflected in the other.
+func (r *Reader) ReadSeeker(random bool, index []byte) (*ReadSeeker, error) {
+	// Read index if provided.
+	if len(index) != 0 {
+		if r.index == nil {
+			r.index = &Index{}
+		}
+		if _, err := r.index.Load(index); err != nil {
+			return nil, ErrCantSeek{Reason: "loading index returned: " + err.Error()}
+		}
+	}
+
+	// Check if input is seekable.
+	rs, ok := r.r.(io.ReadSeeker)
+	if !ok {
+		if !random {
+			return &ReadSeeker{Reader: r}, nil
+		}
+		return nil, ErrCantSeek{Reason: "input stream isn't seekable"}
+	}
+
+	if r.index != nil {
+		// Seekable and index, ok...
+		return &ReadSeeker{Reader: r}, nil
+	}
+
+	// Load from stream.
+	r.index = &Index{}
+
+	// Read current position.
+	pos, err := rs.Seek(0, io.SeekCurrent)
+	if err != nil {
+		return nil, ErrCantSeek{Reason: "seeking input returned: " + err.Error()}
+	}
+	err = r.index.LoadStream(rs)
+	if err != nil {
+		if err == ErrUnsupported {
+			return nil, ErrCantSeek{Reason: "input stream does not contain an index"}
+		}
+		return nil, ErrCantSeek{Reason: "reading index returned: " + err.Error()}
+	}
+
+	// Reset position.
+	_, err = rs.Seek(pos, io.SeekStart)
+	if err != nil {
+		return nil, ErrCantSeek{Reason: "seeking input returned: " + err.Error()}
+	}
+	return &ReadSeeker{Reader: r}, nil
+}
+
+// Seek allows seeking in compressed data.
+func (r *ReadSeeker) Seek(offset int64, whence int) (int64, error) {
+	if r.err != nil {
+		return 0, r.err
+	}
+	if offset == 0 && whence == io.SeekCurrent {
+		return r.blockStart + int64(r.i), nil
+	}
+	if !r.readHeader {
+		// Make sure we read the header.
+		_, r.err = r.Read([]byte{})
+	}
+	rs, ok := r.r.(io.ReadSeeker)
+	if r.index == nil || !ok {
+		if whence == io.SeekCurrent && offset >= 0 {
+			err := r.Skip(offset)
+			return r.blockStart + int64(r.i), err
+		}
+		if whence == io.SeekStart && offset >= r.blockStart+int64(r.i) {
+			err := r.Skip(offset - r.blockStart - int64(r.i))
+			return r.blockStart + int64(r.i), err
+		}
+		return 0, ErrUnsupported
+	}
+
+	switch whence {
+	case io.SeekCurrent:
+		offset += r.blockStart + int64(r.i)
+	case io.SeekEnd:
+		offset = -offset
+	}
+	c, u, err := r.index.Find(offset)
+	if err != nil {
+		return r.blockStart + int64(r.i), err
+	}
+
+	// Seek to next block.
+	_, err = rs.Seek(c, io.SeekStart)
+	if err != nil {
+		return 0, err
+	}
+
+	if offset < 0 {
+		offset = r.index.TotalUncompressed + offset
+	}
+
+	r.i = r.j // Remove rest of current block.
+	if u < offset {
+		// Forward inside block
+		return offset, r.Skip(offset - u)
+	}
+	return offset, nil
+}

 // ReadByte satisfies the io.ByteReader interface.
 func (r *Reader) ReadByte() (byte, error) {
 	if r.err != nil {
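Seek above maps an absolute uncompressed offset to a compressed file position via `Index.Find`, then skips forward inside the located block. A rough stdlib-only sketch of that lookup; the `entry` layout and `find` helper here are illustrative, not the library's actual `Index` type:

```go
package main

import (
	"fmt"
	"sort"
)

// entry pairs an uncompressed offset with the compressed offset of
// the block that starts there (illustrative layout).
type entry struct {
	uncompressed, compressed int64
}

// find returns the entry whose block contains the absolute
// uncompressed offset: the last entry with uncompressed <= off.
// Entries must be sorted by uncompressed offset, as the index requires.
func find(entries []entry, off int64) (entry, error) {
	if len(entries) == 0 || off < entries[0].uncompressed {
		return entry{}, fmt.Errorf("offset %d before first indexed block", off)
	}
	// Locate the first entry strictly greater than off...
	i := sort.Search(len(entries), func(i int) bool {
		return entries[i].uncompressed > off
	})
	// ...so the block containing off is the one just before it.
	return entries[i-1], nil
}

func main() {
	idx := []entry{{0, 0}, {1 << 20, 400_000}, {2 << 20, 820_000}}
	e, err := find(idx, 1<<20+5)
	fmt.Println(e.compressed, err) // compressed seek target for the block
}
```

After seeking to `e.compressed`, a decoder would decompress forward until it reaches the requested uncompressed offset, which is what the `r.Skip(offset - u)` call above does.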
@@ -563,3 +746,17 @@ func (r *Reader) ReadByte() (byte, error) {
 	}
 	return 0, io.ErrNoProgress
 }
+
+// SkippableCB will register a callback for chunks with the specified ID.
+// ID must be a Reserved skippable chunks ID, 0x80-0xfe (inclusive).
+// For each chunk with the ID, the callback is called with the content.
+// Any returned non-nil error will abort decompression.
+// Only one callback per ID is supported, latest sent will be used.
+// Sending a nil function will disable previous callbacks.
+func (r *Reader) SkippableCB(id uint8, fn func(r io.Reader) error) error {
+	if id < 0x80 || id > chunkTypePadding {
+		return fmt.Errorf("ReaderSkippableCB: Invalid id provided, must be 0x80-0xfe (inclusive)")
+	}
+	r.skippableCB[id-0x80] = fn
+	return nil
+}
@@ -399,6 +399,7 @@ type Writer struct {
 	obufLen     int
 	concurrency int
 	written     int64
+	uncompWritten int64 // Bytes sent to compression
 	output      chan chan result
 	buffers     sync.Pool
 	pad         int
@@ -406,12 +407,14 @@ type Writer struct {
 	writer   io.Writer
 	randSrc  io.Reader
 	writerWg sync.WaitGroup
+	index    Index

 	// wroteStreamHeader is whether we have written the stream header.
 	wroteStreamHeader bool
 	paramsOK          bool
 	snappy            bool
 	flushOnWrite      bool
+	appendIndex       bool
 	level             uint8
 }
@@ -422,7 +425,11 @@ const (
 	levelBest
 )

-type result []byte
+type result struct {
+	b []byte
+	// Uncompressed start offset
+	startOffset int64
+}

 // err returns the previously set error.
 // If no error has been set it is set to err if not nil.
@@ -454,6 +461,9 @@ func (w *Writer) Reset(writer io.Writer) {
 	w.wroteStreamHeader = false
 	w.written = 0
 	w.writer = writer
+	w.uncompWritten = 0
+	w.index.reset(w.blockSize)

 	// If we didn't get a writer, stop here.
 	if writer == nil {
 		return
@@ -474,7 +484,8 @@ func (w *Writer) Reset(writer io.Writer) {
 		// Get a queued write.
 		for write := range toWrite {
 			// Wait for the data to be available.
-			in := <-write
+			input := <-write
+			in := input.b
 			if len(in) > 0 {
 				if w.err(nil) == nil {
 					// Don't expose data from previous buffers.
@@ -485,11 +496,12 @@ func (w *Writer) Reset(writer io.Writer) {
 						err = io.ErrShortBuffer
 					}
 					_ = w.err(err)
+					w.err(w.index.add(w.written, input.startOffset))
 					w.written += int64(n)
 				}
 			}
 			if cap(in) >= w.obufLen {
-				w.buffers.Put([]byte(in))
+				w.buffers.Put(in)
 			}
 			// close the incoming write request.
 			// This can be used for synchronizing flushes.
@@ -500,6 +512,9 @@ func (w *Writer) Reset(writer io.Writer) {

 // Write satisfies the io.Writer interface.
 func (w *Writer) Write(p []byte) (nRet int, errRet error) {
+	if err := w.err(nil); err != nil {
+		return 0, err
+	}
 	if w.flushOnWrite {
 		return w.write(p)
 	}
@@ -535,6 +550,9 @@ func (w *Writer) Write(p []byte) (nRet int, errRet error) {
 // The return value n is the number of bytes read.
 // Any error except io.EOF encountered during the read is also returned.
 func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) {
+	if err := w.err(nil); err != nil {
+		return 0, err
+	}
 	if len(w.ibuf) > 0 {
 		err := w.Flush()
 		if err != nil {
@@ -577,6 +595,85 @@ func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) {
 	return n, w.err(nil)
 }

+// AddSkippableBlock will add a skippable block to the stream.
+// The ID must be 0x80-0xfe (inclusive).
+// Length of the skippable block must be <= 16777215 bytes.
+func (w *Writer) AddSkippableBlock(id uint8, data []byte) (err error) {
+	if err := w.err(nil); err != nil {
+		return err
+	}
+	if len(data) == 0 {
+		return nil
+	}
+	if id < 0x80 || id > chunkTypePadding {
+		return fmt.Errorf("invalid skippable block id %x", id)
+	}
+	if len(data) > maxChunkSize {
+		return fmt.Errorf("skippable block exceeded maximum size")
+	}
+	var header [4]byte
+	chunkLen := len(data)
+	header[0] = id
+	header[1] = uint8(chunkLen >> 0)
+	header[2] = uint8(chunkLen >> 8)
+	header[3] = uint8(chunkLen >> 16)
+	if w.concurrency == 1 {
+		write := func(b []byte) error {
+			n, err := w.writer.Write(b)
+			if err = w.err(err); err != nil {
+				return err
+			}
+			if n != len(b) {
+				return w.err(io.ErrShortWrite)
+			}
+			w.written += int64(n)
+			return w.err(nil)
+		}
+		if !w.wroteStreamHeader {
+			w.wroteStreamHeader = true
+			if w.snappy {
+				if err := write([]byte(magicChunkSnappy)); err != nil {
+					return err
+				}
+			} else {
+				if err := write([]byte(magicChunk)); err != nil {
+					return err
+				}
+			}
+		}
+		if err := write(header[:]); err != nil {
+			return err
+		}
+		if err := write(data); err != nil {
+			return err
+		}
+		return nil
+	}
+
+	// Create output...
+	if !w.wroteStreamHeader {
+		w.wroteStreamHeader = true
+		hWriter := make(chan result)
+		w.output <- hWriter
+		if w.snappy {
+			hWriter <- result{startOffset: w.uncompWritten, b: []byte(magicChunkSnappy)}
+		} else {
+			hWriter <- result{startOffset: w.uncompWritten, b: []byte(magicChunk)}
+		}
+	}
+
+	// Copy input.
+	inbuf := w.buffers.Get().([]byte)[:4]
+	copy(inbuf, header[:])
+	inbuf = append(inbuf, data...)
+
+	output := make(chan result, 1)
+	// Queue output.
+	w.output <- output
+	output <- result{startOffset: w.uncompWritten, b: inbuf}
+
+	return nil
+}

 // EncodeBuffer will add a buffer to the stream.
 // This is the fastest way to encode a stream,
 // but the input buffer cannot be written to by the caller
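AddSkippableBlock above emits a 4-byte chunk header: one ID byte followed by a 24-bit little-endian payload length, per the snappy framing format. A self-contained sketch of building and parsing that header; `skippableHeader` and `parseHeader` are illustrative names, not the library's API:

```go
package main

import "fmt"

// skippableHeader builds the 4-byte chunk header used above:
// one ID byte (0x80-0xfe) plus a 24-bit little-endian data length.
func skippableHeader(id uint8, dataLen int) ([4]byte, error) {
	var h [4]byte
	if id < 0x80 || id > 0xfe {
		return h, fmt.Errorf("invalid skippable block id %x", id)
	}
	if dataLen > 1<<24-1 { // the 16777215-byte limit from the doc comment
		return h, fmt.Errorf("skippable block exceeded maximum size")
	}
	h[0] = id
	h[1] = uint8(dataLen >> 0)
	h[2] = uint8(dataLen >> 8)
	h[3] = uint8(dataLen >> 16)
	return h, nil
}

// parseHeader inverts skippableHeader.
func parseHeader(h [4]byte) (id uint8, dataLen int) {
	return h[0], int(h[1]) | int(h[2])<<8 | int(h[3])<<16
}

func main() {
	h, err := skippableHeader(0x80, 300)
	if err != nil {
		panic(err)
	}
	id, n := parseHeader(h)
	fmt.Println(id, n) // 128 300
}
```

A reader that does not recognize the ID can skip exactly `dataLen` bytes, which is what makes these chunks safe to embed (for example, the stream index) without breaking older decoders.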
@@ -614,9 +711,9 @@ func (w *Writer) EncodeBuffer(buf []byte) (err error) {
 		hWriter := make(chan result)
 		w.output <- hWriter
 		if w.snappy {
-			hWriter <- []byte(magicChunkSnappy)
+			hWriter <- result{startOffset: w.uncompWritten, b: []byte(magicChunkSnappy)}
 		} else {
-			hWriter <- []byte(magicChunk)
+			hWriter <- result{startOffset: w.uncompWritten, b: []byte(magicChunk)}
 		}
 	}
@@ -632,6 +729,10 @@ func (w *Writer) EncodeBuffer(buf []byte) (err error) {
 		output := make(chan result)
 		// Queue output now, so we keep order.
 		w.output <- output
+		res := result{
+			startOffset: w.uncompWritten,
+		}
+		w.uncompWritten += int64(len(uncompressed))
 		go func() {
 			checksum := crc(uncompressed)
@@ -664,7 +765,8 @@ func (w *Writer) EncodeBuffer(buf []byte) (err error) {
 			obuf[7] = uint8(checksum >> 24)

 			// Queue final output.
-			output <- obuf
+			res.b = obuf
+			output <- res
 		}()
 	}
 	return nil
@@ -708,9 +810,9 @@ func (w *Writer) write(p []byte) (nRet int, errRet error) {
 		hWriter := make(chan result)
 		w.output <- hWriter
 		if w.snappy {
-			hWriter <- []byte(magicChunkSnappy)
+			hWriter <- result{startOffset: w.uncompWritten, b: []byte(magicChunkSnappy)}
 		} else {
-			hWriter <- []byte(magicChunk)
+			hWriter <- result{startOffset: w.uncompWritten, b: []byte(magicChunk)}
 		}
 	}
@@ -731,6 +833,11 @@ func (w *Writer) write(p []byte) (nRet int, errRet error) {
 		output := make(chan result)
 		// Queue output now, so we keep order.
 		w.output <- output
+		res := result{
+			startOffset: w.uncompWritten,
+		}
+		w.uncompWritten += int64(len(uncompressed))

 		go func() {
 			checksum := crc(uncompressed)
@@ -763,7 +870,8 @@ func (w *Writer) write(p []byte) (nRet int, errRet error) {
 			obuf[7] = uint8(checksum >> 24)

 			// Queue final output.
-			output <- obuf
+			res.b = obuf
+			output <- res

 			// Put unused buffer back in pool.
 			w.buffers.Put(inbuf)
@@ -793,9 +901,9 @@ func (w *Writer) writeFull(inbuf []byte) (errRet error) {
 		hWriter := make(chan result)
 		w.output <- hWriter
 		if w.snappy {
-			hWriter <- []byte(magicChunkSnappy)
+			hWriter <- result{startOffset: w.uncompWritten, b: []byte(magicChunkSnappy)}
 		} else {
-			hWriter <- []byte(magicChunk)
+			hWriter <- result{startOffset: w.uncompWritten, b: []byte(magicChunk)}
 		}
 	}
@@ -806,6 +914,11 @@ func (w *Writer) writeFull(inbuf []byte) (errRet error) {
 	output := make(chan result)
 	// Queue output now, so we keep order.
 	w.output <- output
+	res := result{
+		startOffset: w.uncompWritten,
+	}
+	w.uncompWritten += int64(len(uncompressed))

 	go func() {
 		checksum := crc(uncompressed)
@@ -838,7 +951,8 @@ func (w *Writer) writeFull(inbuf []byte) (errRet error) {
 		obuf[7] = uint8(checksum >> 24)

 		// Queue final output.
-		output <- obuf
+		res.b = obuf
+		output <- res

 		// Put unused buffer back in pool.
 		w.buffers.Put(inbuf)
@@ -912,7 +1026,10 @@ func (w *Writer) writeSync(p []byte) (nRet int, errRet error) {
 		if n != len(obuf) {
 			return 0, w.err(io.ErrShortWrite)
 		}
+		w.err(w.index.add(w.written, w.uncompWritten))
 		w.written += int64(n)
+		w.uncompWritten += int64(len(uncompressed))

 		if chunkType == chunkTypeUncompressedData {
 			// Write uncompressed data.
 			n, err := w.writer.Write(uncompressed)
@@ -961,39 +1078,88 @@ func (w *Writer) Flush() error {
 	res := make(chan result)
 	w.output <- res
 	// Block until this has been picked up.
-	res <- nil
+	res <- result{b: nil, startOffset: w.uncompWritten}
 	// When it is closed, we have flushed.
 	<-res
 	return w.err(nil)
 }

 // Close calls Flush and then closes the Writer.
-// Calling Close multiple times is ok.
+// Calling Close multiple times is ok,
+// but calling CloseIndex after this will make it not return the index.
 func (w *Writer) Close() error {
+	_, err := w.closeIndex(w.appendIndex)
+	return err
+}
+
+// CloseIndex calls Close and returns an index on first call.
+// This is not required if you are only adding index to a stream.
+func (w *Writer) CloseIndex() ([]byte, error) {
+	return w.closeIndex(true)
+}
+
+func (w *Writer) closeIndex(idx bool) ([]byte, error) {
 	err := w.Flush()
 	if w.output != nil {
 		close(w.output)
 		w.writerWg.Wait()
 		w.output = nil
 	}
-	if w.err(nil) == nil && w.writer != nil && w.pad > 0 {
-		add := calcSkippableFrame(w.written, int64(w.pad))
-		frame, err := skippableFrame(w.ibuf[:0], add, w.randSrc)
-		if err = w.err(err); err != nil {
-			return err
-		}
-		_, err2 := w.writer.Write(frame)
-		_ = w.err(err2)
-	}
-	_ = w.err(errClosed)
-	if err == errClosed {
-		return nil
-	}
-	return err
+
+	var index []byte
+	if w.err(nil) == nil && w.writer != nil {
+		// Create index.
+		if idx {
+			compSize := int64(-1)
+			if w.pad <= 1 {
+				compSize = w.written
+			}
+			index = w.index.appendTo(w.ibuf[:0], w.uncompWritten, compSize)
+			// Count as written for padding.
+			if w.appendIndex {
+				w.written += int64(len(index))
+			}
+			if true {
+				_, err := w.index.Load(index)
+				if err != nil {
+					panic(err)
+				}
+			}
+		}
+
+		if w.pad > 1 {
+			tmp := w.ibuf[:0]
+			if len(index) > 0 {
+				// Allocate another buffer.
+				tmp = w.buffers.Get().([]byte)[:0]
+				defer w.buffers.Put(tmp)
+			}
+			add := calcSkippableFrame(w.written, int64(w.pad))
+			frame, err := skippableFrame(tmp, add, w.randSrc)
+			if err = w.err(err); err != nil {
+				return nil, err
+			}
+			n, err2 := w.writer.Write(frame)
+			if err2 == nil && n != len(frame) {
+				err2 = io.ErrShortWrite
+			}
+			_ = w.err(err2)
+		}
+		if len(index) > 0 && w.appendIndex {
+			n, err2 := w.writer.Write(index)
+			if err2 == nil && n != len(index) {
+				err2 = io.ErrShortWrite
+			}
+			_ = w.err(err2)
+		}
+	}
+	err = w.err(errClosed)
+	if err == errClosed {
+		return index, nil
+	}
+	return nil, err
|
||||||
}
|
}
|

-const skippableFrameHeader = 4

// calcSkippableFrame will return a total size to be added for written
// to be divisible by multiple.
// The value will always be > skippableFrameHeader.
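The padding rule those comments describe can be sketched with a standard-library-only stand-in (`padSize` below is a hypothetical name, not the package's actual `calcSkippableFrame`): round the written size up to the next multiple, and if the shortfall is smaller than the 4-byte skippable-frame header, add a whole extra multiple so the padding frame still fits.

```go
package main

import "fmt"

// skippableFrameHeader is the 4-byte header of a snappy skippable frame.
const skippableFrameHeader = 4

// padSize returns how many bytes to append so that written+padSize is
// divisible by multiple. A non-zero result is always at least
// skippableFrameHeader, since the padding itself is a skippable frame.
func padSize(written, multiple int64) int64 {
	if multiple <= 0 {
		return 0
	}
	leftOver := written % multiple
	if leftOver == 0 {
		return 0
	}
	toAdd := multiple - leftOver
	for toAdd < skippableFrameHeader {
		toAdd += multiple
	}
	return toAdd
}

func main() {
	fmt.Println(padSize(100, 64)) // 28: (100+28) % 64 == 0
	fmt.Println(padSize(62, 64))  // 66: 2 bytes would be < the 4-byte header, so add another 64
}
```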
@@ -1057,6 +1223,15 @@ func WriterConcurrency(n int) WriterOption {
	}
}

+// WriterAddIndex will append an index to the end of a stream
+// when it is closed.
+func WriterAddIndex() WriterOption {
+	return func(w *Writer) error {
+		w.appendIndex = true
+		return nil
+	}
+}
+
// WriterBetterCompression will enable better compression.
// EncodeBetter compresses better than Encode but typically with a
// 10-40% speed decrease on both compression and decompression.
@@ -1,7 +1,7 @@
// Code generated by command: go run gen.go -out ../encodeblock_amd64.s -stubs ../encodeblock_amd64.go -pkg=s2. DO NOT EDIT.

-//go:build !appengine && !noasm && gc
+//go:build !appengine && !noasm && gc && !noasm
-// +build !appengine,!noasm,gc
+// +build !appengine,!noasm,gc,!noasm

package s2
File diff suppressed because it is too large
@@ -0,0 +1,525 @@
// Copyright (c) 2022+ Klaus Post. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package s2

import (
	"bytes"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
)

const (
	S2IndexHeader   = "s2idx\x00"
	S2IndexTrailer  = "\x00xdi2s"
	maxIndexEntries = 1 << 16
)

// Index represents an S2/Snappy index.
type Index struct {
	TotalUncompressed int64 // Total Uncompressed size if known. Will be -1 if unknown.
	TotalCompressed   int64 // Total Compressed size if known. Will be -1 if unknown.
	info              []struct {
		compressedOffset   int64
		uncompressedOffset int64
	}
	estBlockUncomp int64
}

func (i *Index) reset(maxBlock int) {
	i.estBlockUncomp = int64(maxBlock)
	i.TotalCompressed = -1
	i.TotalUncompressed = -1
	if len(i.info) > 0 {
		i.info = i.info[:0]
	}
}

// allocInfos will allocate an empty slice of infos.
func (i *Index) allocInfos(n int) {
	if n > maxIndexEntries {
		panic("n > maxIndexEntries")
	}
	i.info = make([]struct {
		compressedOffset   int64
		uncompressedOffset int64
	}, 0, n)
}

// add an uncompressed and compressed pair.
// Entries must be sent in order.
func (i *Index) add(compressedOffset, uncompressedOffset int64) error {
	if i == nil {
		return nil
	}
	lastIdx := len(i.info) - 1
	if lastIdx >= 0 {
		latest := i.info[lastIdx]
		if latest.uncompressedOffset == uncompressedOffset {
			// Uncompressed didn't change, don't add entry,
			// but update start index.
			latest.compressedOffset = compressedOffset
			i.info[lastIdx] = latest
			return nil
		}
		if latest.uncompressedOffset > uncompressedOffset {
			return fmt.Errorf("internal error: Earlier uncompressed received (%d > %d)", latest.uncompressedOffset, uncompressedOffset)
		}
		if latest.compressedOffset > compressedOffset {
			return fmt.Errorf("internal error: Earlier compressed received (%d > %d)", latest.uncompressedOffset, uncompressedOffset)
		}
	}
	i.info = append(i.info, struct {
		compressedOffset   int64
		uncompressedOffset int64
	}{compressedOffset: compressedOffset, uncompressedOffset: uncompressedOffset})
	return nil
}

// Find the offset at or before the wanted (uncompressed) offset.
// If offset is 0 or positive it is the offset from the beginning of the file.
// If the uncompressed size is known, the offset must be within the file.
// If an offset outside the file is requested io.ErrUnexpectedEOF is returned.
// If the offset is negative, it is interpreted as the distance from the end of the file,
// where -1 represents the last byte.
// If offset from the end of the file is requested, but size is unknown,
// ErrUnsupported will be returned.
func (i *Index) Find(offset int64) (compressedOff, uncompressedOff int64, err error) {
	if i.TotalUncompressed < 0 {
		return 0, 0, ErrCorrupt
	}
	if offset < 0 {
		offset = i.TotalUncompressed + offset
		if offset < 0 {
			return 0, 0, io.ErrUnexpectedEOF
		}
	}
	if offset > i.TotalUncompressed {
		return 0, 0, io.ErrUnexpectedEOF
	}
	for _, info := range i.info {
		if info.uncompressedOffset > offset {
			break
		}
		compressedOff = info.compressedOffset
		uncompressedOff = info.uncompressedOffset
	}
	return compressedOff, uncompressedOff, nil
}

// reduce to stay below maxIndexEntries
func (i *Index) reduce() {
	if len(i.info) < maxIndexEntries && i.estBlockUncomp >= 1<<20 {
		return
	}

	// Algorithm, keep 1, remove removeN entries...
	removeN := (len(i.info) + 1) / maxIndexEntries
	src := i.info
	j := 0

	// Each block should be at least 1MB, but don't reduce below 1000 entries.
	for i.estBlockUncomp*(int64(removeN)+1) < 1<<20 && len(i.info)/(removeN+1) > 1000 {
		removeN++
	}
	for idx := 0; idx < len(src); idx++ {
		i.info[j] = src[idx]
		j++
		idx += removeN
	}
	i.info = i.info[:j]
	// Update maxblock estimate.
	i.estBlockUncomp += i.estBlockUncomp * int64(removeN)
}

func (i *Index) appendTo(b []byte, uncompTotal, compTotal int64) []byte {
	i.reduce()
	var tmp [binary.MaxVarintLen64]byte

	initSize := len(b)
	// We make the start a skippable header+size.
	b = append(b, ChunkTypeIndex, 0, 0, 0)
	b = append(b, []byte(S2IndexHeader)...)
	// Total Uncompressed size
	n := binary.PutVarint(tmp[:], uncompTotal)
	b = append(b, tmp[:n]...)
	// Total Compressed size
	n = binary.PutVarint(tmp[:], compTotal)
	b = append(b, tmp[:n]...)
	// Put EstBlockUncomp size
	n = binary.PutVarint(tmp[:], i.estBlockUncomp)
	b = append(b, tmp[:n]...)
	// Put length
	n = binary.PutVarint(tmp[:], int64(len(i.info)))
	b = append(b, tmp[:n]...)

	// Check if we should add uncompressed offsets
	var hasUncompressed byte
	for idx, info := range i.info {
		if idx == 0 {
			if info.uncompressedOffset != 0 {
				hasUncompressed = 1
				break
			}
			continue
		}
		if info.uncompressedOffset != i.info[idx-1].uncompressedOffset+i.estBlockUncomp {
			hasUncompressed = 1
			break
		}
	}
	b = append(b, hasUncompressed)

	// Add each entry
	if hasUncompressed == 1 {
		for idx, info := range i.info {
			uOff := info.uncompressedOffset
			if idx > 0 {
				prev := i.info[idx-1]
				uOff -= prev.uncompressedOffset + (i.estBlockUncomp)
			}
			n = binary.PutVarint(tmp[:], uOff)
			b = append(b, tmp[:n]...)
		}
	}

	// Initial compressed size estimate.
	cPredict := i.estBlockUncomp / 2

	for idx, info := range i.info {
		cOff := info.compressedOffset
		if idx > 0 {
			prev := i.info[idx-1]
			cOff -= prev.compressedOffset + cPredict
			// Update compressed size prediction, with half the error.
			cPredict += cOff / 2
		}
		n = binary.PutVarint(tmp[:], cOff)
		b = append(b, tmp[:n]...)
	}

	// Add Total Size.
	// Stored as fixed size for easier reading.
	binary.LittleEndian.PutUint32(tmp[:], uint32(len(b)-initSize+4+len(S2IndexTrailer)))
	b = append(b, tmp[:4]...)
	// Trailer
	b = append(b, []byte(S2IndexTrailer)...)

	// Update size
	chunkLen := len(b) - initSize - skippableFrameHeader
	b[initSize+1] = uint8(chunkLen >> 0)
	b[initSize+2] = uint8(chunkLen >> 8)
	b[initSize+3] = uint8(chunkLen >> 16)
	//fmt.Printf("chunklen: 0x%x Uncomp:%d, Comp:%d\n", chunkLen, uncompTotal, compTotal)
	return b
}

// Load a binary index.
// A zero value Index can be used or a previous one can be reused.
func (i *Index) Load(b []byte) ([]byte, error) {
	if len(b) <= 4+len(S2IndexHeader)+len(S2IndexTrailer) {
		return b, io.ErrUnexpectedEOF
	}
	if b[0] != ChunkTypeIndex {
		return b, ErrCorrupt
	}
	chunkLen := int(b[1]) | int(b[2])<<8 | int(b[3])<<16
	b = b[4:]

	// Validate we have enough...
	if len(b) < chunkLen {
		return b, io.ErrUnexpectedEOF
	}
	if !bytes.Equal(b[:len(S2IndexHeader)], []byte(S2IndexHeader)) {
		return b, ErrUnsupported
	}
	b = b[len(S2IndexHeader):]

	// Total Uncompressed
	if v, n := binary.Varint(b); n <= 0 || v < 0 {
		return b, ErrCorrupt
	} else {
		i.TotalUncompressed = v
		b = b[n:]
	}

	// Total Compressed
	if v, n := binary.Varint(b); n <= 0 {
		return b, ErrCorrupt
	} else {
		i.TotalCompressed = v
		b = b[n:]
	}

	// Read EstBlockUncomp
	if v, n := binary.Varint(b); n <= 0 {
		return b, ErrCorrupt
	} else {
		if v < 0 {
			return b, ErrCorrupt
		}
		i.estBlockUncomp = v
		b = b[n:]
	}

	var entries int
	if v, n := binary.Varint(b); n <= 0 {
		return b, ErrCorrupt
	} else {
		if v < 0 || v > maxIndexEntries {
			return b, ErrCorrupt
		}
		entries = int(v)
		b = b[n:]
	}
	if cap(i.info) < entries {
		i.allocInfos(entries)
	}
	i.info = i.info[:entries]

	if len(b) < 1 {
		return b, io.ErrUnexpectedEOF
	}
	hasUncompressed := b[0]
	b = b[1:]
	if hasUncompressed&1 != hasUncompressed {
		return b, ErrCorrupt
	}

	// Add each uncompressed entry
	for idx := range i.info {
		var uOff int64
		if hasUncompressed != 0 {
			// Load delta
			if v, n := binary.Varint(b); n <= 0 {
				return b, ErrCorrupt
			} else {
				uOff = v
				b = b[n:]
			}
		}

		if idx > 0 {
			prev := i.info[idx-1].uncompressedOffset
			uOff += prev + (i.estBlockUncomp)
			if uOff <= prev {
				return b, ErrCorrupt
			}
		}
		if uOff < 0 {
			return b, ErrCorrupt
		}
		i.info[idx].uncompressedOffset = uOff
	}

	// Initial compressed size estimate.
	cPredict := i.estBlockUncomp / 2

	// Add each compressed entry
	for idx := range i.info {
		var cOff int64
		if v, n := binary.Varint(b); n <= 0 {
			return b, ErrCorrupt
		} else {
			cOff = v
			b = b[n:]
		}

		if idx > 0 {
			// Update compressed size prediction, with half the error.
			cPredictNew := cPredict + cOff/2

			prev := i.info[idx-1].compressedOffset
			cOff += prev + cPredict
			if cOff <= prev {
				return b, ErrCorrupt
			}
			cPredict = cPredictNew
		}
		if cOff < 0 {
			return b, ErrCorrupt
		}
		i.info[idx].compressedOffset = cOff
	}
	if len(b) < 4+len(S2IndexTrailer) {
		return b, io.ErrUnexpectedEOF
	}
	// Skip size...
	b = b[4:]

	// Check trailer...
	if !bytes.Equal(b[:len(S2IndexTrailer)], []byte(S2IndexTrailer)) {
		return b, ErrCorrupt
	}
	return b[len(S2IndexTrailer):], nil
}

// LoadStream will load an index from the end of the supplied stream.
// ErrUnsupported will be returned if the signature cannot be found.
// ErrCorrupt will be returned if unexpected values are found.
// io.ErrUnexpectedEOF is returned if there are too few bytes.
// IO errors are returned as-is.
func (i *Index) LoadStream(rs io.ReadSeeker) error {
	// Go to end.
	_, err := rs.Seek(-10, io.SeekEnd)
	if err != nil {
		return err
	}
	var tmp [10]byte
	_, err = io.ReadFull(rs, tmp[:])
	if err != nil {
		return err
	}
	// Check trailer...
	if !bytes.Equal(tmp[4:4+len(S2IndexTrailer)], []byte(S2IndexTrailer)) {
		return ErrUnsupported
	}
	sz := binary.LittleEndian.Uint32(tmp[:4])
	if sz > maxChunkSize+skippableFrameHeader {
		return ErrCorrupt
	}
	_, err = rs.Seek(-int64(sz), io.SeekEnd)
	if err != nil {
		return err
	}

	// Read index.
	buf := make([]byte, sz)
	_, err = io.ReadFull(rs, buf)
	if err != nil {
		return err
	}
	_, err = i.Load(buf)
	return err
}

// IndexStream will return an index for a stream.
// The stream structure will be checked, but
// data within blocks is not verified.
// The returned index can either be appended to the end of the stream
// or stored separately.
func IndexStream(r io.Reader) ([]byte, error) {
	var i Index
	var buf [maxChunkSize]byte
	var readHeader bool
	for {
		_, err := io.ReadFull(r, buf[:4])
		if err != nil {
			if err == io.EOF {
				return i.appendTo(nil, i.TotalUncompressed, i.TotalCompressed), nil
			}
			return nil, err
		}
		// Start of this chunk.
		startChunk := i.TotalCompressed
		i.TotalCompressed += 4

		chunkType := buf[0]
		if !readHeader {
			if chunkType != chunkTypeStreamIdentifier {
				return nil, ErrCorrupt
			}
			readHeader = true
		}
		chunkLen := int(buf[1]) | int(buf[2])<<8 | int(buf[3])<<16
		if chunkLen < checksumSize {
			return nil, ErrCorrupt
		}

		i.TotalCompressed += int64(chunkLen)
		_, err = io.ReadFull(r, buf[:chunkLen])
		if err != nil {
			return nil, io.ErrUnexpectedEOF
		}
		// The chunk types are specified at
		// https://github.com/google/snappy/blob/master/framing_format.txt
		switch chunkType {
		case chunkTypeCompressedData:
			// Section 4.2. Compressed data (chunk type 0x00).
			// Skip checksum.
			dLen, err := DecodedLen(buf[checksumSize:])
			if err != nil {
				return nil, err
			}
			if dLen > maxBlockSize {
				return nil, ErrCorrupt
			}
			if i.estBlockUncomp == 0 {
				// Use first block for estimate...
				i.estBlockUncomp = int64(dLen)
			}
			err = i.add(startChunk, i.TotalUncompressed)
			if err != nil {
				return nil, err
			}
			i.TotalUncompressed += int64(dLen)
			continue
		case chunkTypeUncompressedData:
			n2 := chunkLen - checksumSize
			if n2 > maxBlockSize {
				return nil, ErrCorrupt
			}
			if i.estBlockUncomp == 0 {
				// Use first block for estimate...
				i.estBlockUncomp = int64(n2)
			}
			err = i.add(startChunk, i.TotalUncompressed)
			if err != nil {
				return nil, err
			}
			i.TotalUncompressed += int64(n2)
			continue
		case chunkTypeStreamIdentifier:
			// Section 4.1. Stream identifier (chunk type 0xff).
			if chunkLen != len(magicBody) {
				return nil, ErrCorrupt
			}

			if string(buf[:len(magicBody)]) != magicBody {
				if string(buf[:len(magicBody)]) != magicBodySnappy {
					return nil, ErrCorrupt
				}
			}

			continue
		}

		if chunkType <= 0x7f {
			// Section 4.5. Reserved unskippable chunks (chunk types 0x02-0x7f).
			return nil, ErrUnsupported
		}
		if chunkLen > maxChunkSize {
			return nil, ErrUnsupported
		}
		// Section 4.4 Padding (chunk type 0xfe).
		// Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
	}
}

// JSON returns the index as JSON text.
func (i *Index) JSON() []byte {
	x := struct {
		TotalUncompressed int64 `json:"total_uncompressed"` // Total Uncompressed size if known. Will be -1 if unknown.
		TotalCompressed   int64 `json:"total_compressed"`   // Total Compressed size if known. Will be -1 if unknown.
		Offsets           []struct {
			CompressedOffset   int64 `json:"compressed"`
			UncompressedOffset int64 `json:"uncompressed"`
		} `json:"offsets"`
		EstBlockUncomp int64 `json:"est_block_uncompressed"`
	}{
		TotalUncompressed: i.TotalUncompressed,
		TotalCompressed:   i.TotalCompressed,
		EstBlockUncomp:    i.estBlockUncomp,
	}
	for _, v := range i.info {
		x.Offsets = append(x.Offsets, struct {
			CompressedOffset   int64 `json:"compressed"`
			UncompressedOffset int64 `json:"uncompressed"`
		}{CompressedOffset: v.compressedOffset, UncompressedOffset: v.uncompressedOffset})
	}
	b, _ := json.MarshalIndent(x, "", "  ")
	return b
}
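The `appendTo`/`Load` pair above stores each compressed offset as a signed varint delta against a running prediction, correcting the prediction by half the observed error. The round trip of that delta scheme can be sketched in isolation with only `encoding/binary` (the function names here are illustrative, not part of the package):

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// encodeOffsets stores each offset as a varint delta against a prediction;
// the prediction starts at half the estimated block size and is corrected
// by half the error after each entry, mirroring the index's scheme.
func encodeOffsets(offsets []int64, estBlock int64) []byte {
	var b []byte
	var tmp [binary.MaxVarintLen64]byte
	cPredict := estBlock / 2
	for idx, off := range offsets {
		v := off
		if idx > 0 {
			v -= offsets[idx-1] + cPredict
			cPredict += v / 2
		}
		n := binary.PutVarint(tmp[:], v)
		b = append(b, tmp[:n]...)
	}
	return b
}

// decodeOffsets reverses encodeOffsets, reconstructing absolute offsets
// while tracking the same prediction.
func decodeOffsets(b []byte, count int, estBlock int64) []int64 {
	out := make([]int64, 0, count)
	cPredict := estBlock / 2
	for idx := 0; idx < count; idx++ {
		v, n := binary.Varint(b)
		b = b[n:]
		if idx > 0 {
			cPredictNew := cPredict + v/2
			v += out[idx-1] + cPredict
			cPredict = cPredictNew
		}
		out = append(out, v)
	}
	return out
}

func main() {
	offs := []int64{0, 480000, 961000, 1439000}
	enc := encodeOffsets(offs, 1<<20)
	fmt.Println(decodeOffsets(enc, len(offs), 1<<20)) // [0 480000 961000 1439000]
}
```

Because near-constant block sizes make the deltas small, most entries cost only one or two varint bytes.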
@@ -87,6 +87,9 @@ const (
	// minBlockSize is the minimum size of block setting when creating a writer.
	minBlockSize = 4 << 10

+	skippableFrameHeader = 4
+	maxChunkSize         = 1<<24 - 1 // 16777215
+
	// Default block size
	defaultBlockSize = 1 << 20

@@ -99,6 +102,7 @@ const (
const (
	chunkTypeCompressedData   = 0x00
	chunkTypeUncompressedData = 0x01
+	ChunkTypeIndex            = 0x99
	chunkTypePadding          = 0xfe
	chunkTypeStreamIdentifier = 0xff
)
@@ -0,0 +1,24 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so

# Folders
_obj
_test

# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out

*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*

_testmain.go

*.exe
*.test
*.prof
@@ -0,0 +1,10 @@
language: go
go_import_path: github.com/pkg/errors
go:
  - 1.11.x
  - 1.12.x
  - 1.13.x
  - tip

script:
  - make check
@@ -0,0 +1,23 @@
Copyright (c) 2015, Dave Cheney <dave@cheney.net>
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
@@ -0,0 +1,44 @@
PKGS := github.com/pkg/errors
SRCDIRS := $(shell go list -f '{{.Dir}}' $(PKGS))
GO := go

check: test vet gofmt misspell unconvert staticcheck ineffassign unparam

test:
	$(GO) test $(PKGS)

vet: | test
	$(GO) vet $(PKGS)

staticcheck:
	$(GO) get honnef.co/go/tools/cmd/staticcheck
	staticcheck -checks all $(PKGS)

misspell:
	$(GO) get github.com/client9/misspell/cmd/misspell
	misspell \
		-locale GB \
		-error \
		*.md *.go

unconvert:
	$(GO) get github.com/mdempsky/unconvert
	unconvert -v $(PKGS)

ineffassign:
	$(GO) get github.com/gordonklaus/ineffassign
	find $(SRCDIRS) -name '*.go' | xargs ineffassign

pedantic: check errcheck

unparam:
	$(GO) get mvdan.cc/unparam
	unparam ./...

errcheck:
	$(GO) get github.com/kisielk/errcheck
	errcheck $(PKGS)

gofmt:
	@echo Checking code is gofmted
	@test -z "$(shell gofmt -s -l -d -e $(SRCDIRS) | tee /dev/stderr)"
@@ -0,0 +1,59 @@
# errors [![Travis-CI](https://travis-ci.org/pkg/errors.svg)](https://travis-ci.org/pkg/errors) [![AppVeyor](https://ci.appveyor.com/api/projects/status/b98mptawhudj53ep/branch/master?svg=true)](https://ci.appveyor.com/project/davecheney/errors/branch/master) [![GoDoc](https://godoc.org/github.com/pkg/errors?status.svg)](http://godoc.org/github.com/pkg/errors) [![Report card](https://goreportcard.com/badge/github.com/pkg/errors)](https://goreportcard.com/report/github.com/pkg/errors) [![Sourcegraph](https://sourcegraph.com/github.com/pkg/errors/-/badge.svg)](https://sourcegraph.com/github.com/pkg/errors?badge)

Package errors provides simple error handling primitives.

`go get github.com/pkg/errors`

The traditional error handling idiom in Go is roughly akin to
```go
if err != nil {
        return err
}
```
which applied recursively up the call stack results in error reports without context or debugging information. The errors package allows programmers to add context to the failure path in their code in a way that does not destroy the original value of the error.

## Adding context to an error

The errors.Wrap function returns a new error that adds context to the original error. For example
```go
_, err := ioutil.ReadAll(r)
if err != nil {
        return errors.Wrap(err, "read failed")
}
```

## Retrieving the cause of an error

Using `errors.Wrap` constructs a stack of errors, adding context to the preceding error. Depending on the nature of the error it may be necessary to reverse the operation of errors.Wrap to retrieve the original error for inspection. Any error value which implements this interface can be inspected by `errors.Cause`.
```go
type causer interface {
        Cause() error
}
```
`errors.Cause` will recursively retrieve the topmost error which does not implement `causer`, which is assumed to be the original cause. For example:
```go
switch err := errors.Cause(err).(type) {
case *MyError:
        // handle specifically
default:
        // unknown error
}
```

[Read the package documentation for more information](https://godoc.org/github.com/pkg/errors).

## Roadmap

With the upcoming [Go2 error proposals](https://go.googlesource.com/proposal/+/master/design/go2draft.md) this package is moving into maintenance mode. The roadmap for a 1.0 release is as follows:

- 0.9. Remove pre Go 1.9 and Go 1.10 support, address outstanding pull requests (if possible)
- 1.0. Final release.

## Contributing

Because of the Go2 errors changes, this package is not accepting proposals for new functionality. With that said, we welcome pull requests, bug fixes and issue reports.

Before sending a PR, please discuss your change by raising an issue.

## License

BSD-2-Clause
@@ -0,0 +1,32 @@
version: build-{build}.{branch}

clone_folder: C:\gopath\src\github.com\pkg\errors
shallow_clone: true # for startup speed

environment:
  GOPATH: C:\gopath

platform:
  - x64

# http://www.appveyor.com/docs/installed-software
install:
  # some helpful output for debugging builds
  - go version
  - go env
  # pre-installed MinGW at C:\MinGW is 32bit only
  # but MSYS2 at C:\msys64 has mingw64
  - set PATH=C:\msys64\mingw64\bin;%PATH%
  - gcc --version
  - g++ --version

build_script:
  - go install -v ./...

test_script:
  - set PATH=C:\gopath\bin;%PATH%
  - go test -v ./...

#artifacts:
# - path: '%GOPATH%\bin\*.exe'
deploy: off
@@ -0,0 +1,288 @@
// Package errors provides simple error handling primitives.
//
// The traditional error handling idiom in Go is roughly akin to
//
//     if err != nil {
//             return err
//     }
//
// which when applied recursively up the call stack results in error reports
// without context or debugging information. The errors package allows
// programmers to add context to the failure path in their code in a way
// that does not destroy the original value of the error.
//
// Adding context to an error
//
// The errors.Wrap function returns a new error that adds context to the
// original error by recording a stack trace at the point Wrap is called,
// together with the supplied message. For example
//
//     _, err := ioutil.ReadAll(r)
//     if err != nil {
//             return errors.Wrap(err, "read failed")
//     }
//
// If additional control is required, the errors.WithStack and
// errors.WithMessage functions destructure errors.Wrap into its component
// operations: annotating an error with a stack trace and with a message,
// respectively.
//
// Retrieving the cause of an error
//
// Using errors.Wrap constructs a stack of errors, adding context to the
// preceding error. Depending on the nature of the error it may be necessary
// to reverse the operation of errors.Wrap to retrieve the original error
// for inspection. Any error value which implements this interface
//
//     type causer interface {
//             Cause() error
//     }
//
// can be inspected by errors.Cause. errors.Cause will recursively retrieve
// the topmost error that does not implement causer, which is assumed to be
// the original cause. For example:
//
//     switch err := errors.Cause(err).(type) {
//     case *MyError:
//             // handle specifically
//     default:
//             // unknown error
//     }
//
// Although the causer interface is not exported by this package, it is
// considered a part of its stable public interface.
//
// Formatted printing of errors
//
// All error values returned from this package implement fmt.Formatter and can
// be formatted by the fmt package. The following verbs are supported:
//
//     %s    print the error. If the error has a Cause it will be
//           printed recursively.
//     %v    see %s
//     %+v   extended format. Each Frame of the error's StackTrace will
//           be printed in detail.
//
// Retrieving the stack trace of an error or wrapper
//
// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are
// invoked. This information can be retrieved with the following interface:
//
//     type stackTracer interface {
//             StackTrace() errors.StackTrace
//     }
//
// The returned errors.StackTrace type is defined as
//
//     type StackTrace []Frame
//
// The Frame type represents a call site in the stack trace. Frame supports
// the fmt.Formatter interface that can be used for printing information about
// the stack trace of this error. For example:
//
//     if err, ok := err.(stackTracer); ok {
//             for _, f := range err.StackTrace() {
//                     fmt.Printf("%+s:%d\n", f, f)
//             }
//     }
//
// Although the stackTracer interface is not exported by this package, it is
// considered a part of its stable public interface.
//
// See the documentation for Frame.Format for more details.
package errors

import (
	"fmt"
	"io"
)

// New returns an error with the supplied message.
// New also records the stack trace at the point it was called.
func New(message string) error {
	return &fundamental{
		msg:   message,
		stack: callers(),
	}
}

// Errorf formats according to a format specifier and returns the string
// as a value that satisfies error.
// Errorf also records the stack trace at the point it was called.
func Errorf(format string, args ...interface{}) error {
	return &fundamental{
		msg:   fmt.Sprintf(format, args...),
		stack: callers(),
	}
}

// fundamental is an error that has a message and a stack, but no caller.
type fundamental struct {
	msg string
	*stack
}

func (f *fundamental) Error() string { return f.msg }

func (f *fundamental) Format(s fmt.State, verb rune) {
	switch verb {
	case 'v':
		if s.Flag('+') {
			io.WriteString(s, f.msg)
			f.stack.Format(s, verb)
			return
		}
		fallthrough
	case 's':
		io.WriteString(s, f.msg)
	case 'q':
		fmt.Fprintf(s, "%q", f.msg)
	}
}

// WithStack annotates err with a stack trace at the point WithStack was called.
// If err is nil, WithStack returns nil.
func WithStack(err error) error {
	if err == nil {
		return nil
	}
	return &withStack{
		err,
		callers(),
	}
}

type withStack struct {
	error
	*stack
}

func (w *withStack) Cause() error { return w.error }

// Unwrap provides compatibility for Go 1.13 error chains.
func (w *withStack) Unwrap() error { return w.error }

func (w *withStack) Format(s fmt.State, verb rune) {
	switch verb {
	case 'v':
		if s.Flag('+') {
			fmt.Fprintf(s, "%+v", w.Cause())
			w.stack.Format(s, verb)
			return
		}
		fallthrough
	case 's':
		io.WriteString(s, w.Error())
	case 'q':
		fmt.Fprintf(s, "%q", w.Error())
	}
}

// Wrap returns an error annotating err with a stack trace
// at the point Wrap is called, and the supplied message.
// If err is nil, Wrap returns nil.
func Wrap(err error, message string) error {
	if err == nil {
		return nil
	}
	err = &withMessage{
		cause: err,
		msg:   message,
	}
	return &withStack{
		err,
		callers(),
	}
}

// Wrapf returns an error annotating err with a stack trace
// at the point Wrapf is called, and the format specifier.
// If err is nil, Wrapf returns nil.
func Wrapf(err error, format string, args ...interface{}) error {
	if err == nil {
		return nil
	}
	err = &withMessage{
		cause: err,
		msg:   fmt.Sprintf(format, args...),
	}
	return &withStack{
		err,
		callers(),
	}
}

// WithMessage annotates err with a new message.
// If err is nil, WithMessage returns nil.
func WithMessage(err error, message string) error {
	if err == nil {
		return nil
	}
	return &withMessage{
		cause: err,
		msg:   message,
	}
}

// WithMessagef annotates err with the format specifier.
// If err is nil, WithMessagef returns nil.
func WithMessagef(err error, format string, args ...interface{}) error {
	if err == nil {
		return nil
	}
	return &withMessage{
		cause: err,
		msg:   fmt.Sprintf(format, args...),
	}
}

type withMessage struct {
	cause error
	msg   string
}

func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() }
func (w *withMessage) Cause() error  { return w.cause }

// Unwrap provides compatibility for Go 1.13 error chains.
func (w *withMessage) Unwrap() error { return w.cause }

func (w *withMessage) Format(s fmt.State, verb rune) {
	switch verb {
	case 'v':
		if s.Flag('+') {
			fmt.Fprintf(s, "%+v\n", w.Cause())
			io.WriteString(s, w.msg)
			return
		}
		fallthrough
	case 's', 'q':
		io.WriteString(s, w.Error())
	}
}

// Cause returns the underlying cause of the error, if possible.
// An error value has a cause if it implements the following
// interface:
//
//     type causer interface {
//            Cause() error
//     }
//
// If the error does not implement Cause, the original error will
// be returned. If the error is nil, nil will be returned without further
// investigation.
func Cause(err error) error {
	type causer interface {
		Cause() error
	}

	for err != nil {
		cause, ok := err.(causer)
		if !ok {
			break
		}
		err = cause.Cause()
	}
	return err
}
@@ -0,0 +1,38 @@
// +build go1.13

package errors

import (
	stderrors "errors"
)

// Is reports whether any error in err's chain matches target.
//
// The chain consists of err itself followed by the sequence of errors obtained by
// repeatedly calling Unwrap.
//
// An error is considered to match a target if it is equal to that target or if
// it implements a method Is(error) bool such that Is(target) returns true.
func Is(err, target error) bool { return stderrors.Is(err, target) }

// As finds the first error in err's chain that matches target, and if so, sets
// target to that error value and returns true.
//
// The chain consists of err itself followed by the sequence of errors obtained by
// repeatedly calling Unwrap.
//
// An error matches target if the error's concrete value is assignable to the value
// pointed to by target, or if the error has a method As(interface{}) bool such that
// As(target) returns true. In the latter case, the As method is responsible for
// setting target.
//
// As will panic if target is not a non-nil pointer to either a type that implements
// error, or to any interface type. As returns false if err is nil.
func As(err error, target interface{}) bool { return stderrors.As(err, target) }

// Unwrap returns the result of calling the Unwrap method on err, if err's
// type contains an Unwrap method returning error.
// Otherwise, Unwrap returns nil.
func Unwrap(err error) error {
	return stderrors.Unwrap(err)
}
@@ -0,0 +1,177 @@
package errors

import (
	"fmt"
	"io"
	"path"
	"runtime"
	"strconv"
	"strings"
)

// Frame represents a program counter inside a stack frame.
// For historical reasons if Frame is interpreted as a uintptr
// its value represents the program counter + 1.
type Frame uintptr

// pc returns the program counter for this frame;
// multiple frames may have the same PC value.
func (f Frame) pc() uintptr { return uintptr(f) - 1 }

// file returns the full path to the file that contains the
// function for this Frame's pc.
func (f Frame) file() string {
	fn := runtime.FuncForPC(f.pc())
	if fn == nil {
		return "unknown"
	}
	file, _ := fn.FileLine(f.pc())
	return file
}

// line returns the line number of source code of the
// function for this Frame's pc.
func (f Frame) line() int {
	fn := runtime.FuncForPC(f.pc())
	if fn == nil {
		return 0
	}
	_, line := fn.FileLine(f.pc())
	return line
}

// name returns the name of this function, if known.
func (f Frame) name() string {
	fn := runtime.FuncForPC(f.pc())
	if fn == nil {
		return "unknown"
	}
	return fn.Name()
}

// Format formats the frame according to the fmt.Formatter interface.
//
//    %s    source file
//    %d    source line
//    %n    function name
//    %v    equivalent to %s:%d
//
// Format accepts flags that alter the printing of some verbs, as follows:
//
//    %+s   function name and path of source file relative to the compile time
//          GOPATH separated by \n\t (<funcname>\n\t<path>)
//    %+v   equivalent to %+s:%d
func (f Frame) Format(s fmt.State, verb rune) {
	switch verb {
	case 's':
		switch {
		case s.Flag('+'):
			io.WriteString(s, f.name())
			io.WriteString(s, "\n\t")
			io.WriteString(s, f.file())
		default:
			io.WriteString(s, path.Base(f.file()))
		}
	case 'd':
		io.WriteString(s, strconv.Itoa(f.line()))
	case 'n':
		io.WriteString(s, funcname(f.name()))
	case 'v':
		f.Format(s, 's')
		io.WriteString(s, ":")
		f.Format(s, 'd')
	}
}

// MarshalText formats a stacktrace Frame as a text string. The output is the
// same as that of fmt.Sprintf("%+v", f), but without newlines or tabs.
func (f Frame) MarshalText() ([]byte, error) {
	name := f.name()
	if name == "unknown" {
		return []byte(name), nil
	}
	return []byte(fmt.Sprintf("%s %s:%d", name, f.file(), f.line())), nil
}

// StackTrace is stack of Frames from innermost (newest) to outermost (oldest).
type StackTrace []Frame

// Format formats the stack of Frames according to the fmt.Formatter interface.
//
//    %s	lists source files for each Frame in the stack
//    %v	lists the source file and line number for each Frame in the stack
//
// Format accepts flags that alter the printing of some verbs, as follows:
//
//    %+v   Prints filename, function, and line number for each Frame in the stack.
func (st StackTrace) Format(s fmt.State, verb rune) {
	switch verb {
	case 'v':
		switch {
		case s.Flag('+'):
			for _, f := range st {
				io.WriteString(s, "\n")
				f.Format(s, verb)
			}
		case s.Flag('#'):
			fmt.Fprintf(s, "%#v", []Frame(st))
		default:
			st.formatSlice(s, verb)
		}
	case 's':
		st.formatSlice(s, verb)
	}
}

// formatSlice will format this StackTrace into the given buffer as a slice of
// Frame, only valid when called with '%s' or '%v'.
func (st StackTrace) formatSlice(s fmt.State, verb rune) {
	io.WriteString(s, "[")
	for i, f := range st {
		if i > 0 {
			io.WriteString(s, " ")
		}
		f.Format(s, verb)
	}
	io.WriteString(s, "]")
}

// stack represents a stack of program counters.
type stack []uintptr

func (s *stack) Format(st fmt.State, verb rune) {
	switch verb {
	case 'v':
		switch {
		case st.Flag('+'):
			for _, pc := range *s {
				f := Frame(pc)
				fmt.Fprintf(st, "\n%+v", f)
			}
		}
	}
}

func (s *stack) StackTrace() StackTrace {
	f := make([]Frame, len(*s))
	for i := 0; i < len(f); i++ {
		f[i] = Frame((*s)[i])
	}
	return f
}

func callers() *stack {
	const depth = 32
	var pcs [depth]uintptr
	n := runtime.Callers(3, pcs[:])
	var st stack = pcs[0:n]
	return &st
}

// funcname removes the path prefix component of a function's name reported by func.Name().
func funcname(name string) string {
	i := strings.LastIndex(name, "/")
	name = name[i+1:]
	i = strings.Index(name, ".")
	return name[i+1:]
}
@@ -0,0 +1,5 @@
# Circle CI directory
.circleci

# Example directory
examples
@@ -0,0 +1,25 @@
root = true

[*]
end_of_line = lf
indent_size = 4
indent_style = space
insert_final_newline = true
trim_trailing_whitespace = true
charset = utf-8

[*.{yml,yaml}]
indent_size = 2

[*.go]
indent_size = 8
indent_style = tab

[*.json]
indent_size = 4
indent_style = space

[Makefile]
indent_style = tab
indent_size = 4
@@ -0,0 +1,2 @@
/vendor
.idea
@@ -0,0 +1,79 @@
run:
  concurrency: 4
  deadline: 1m
  issues-exit-code: 1
  tests: true

output:
  format: colored-line-number
  print-issued-lines: true
  print-linter-name: true

linters-settings:
  errcheck:
    check-type-assertions: false
    check-blank: false
  govet:
    check-shadowing: false
    use-installed-packages: false
  golint:
    min-confidence: 0.8
  gofmt:
    simplify: true
  gocyclo:
    min-complexity: 10
  maligned:
    suggest-new: true
  dupl:
    threshold: 80
  goconst:
    min-len: 3
    min-occurrences: 3
  misspell:
    locale: US
  lll:
    line-length: 140
  unused:
    check-exported: false
  unparam:
    algo: cha
    check-exported: false
  nakedret:
    max-func-lines: 30

linters:
  enable:
    - megacheck
    - govet
    - errcheck
    - gas
    - structcheck
    - varcheck
    - ineffassign
    - deadcode
    - typecheck
    - unconvert
    - gocyclo
    - gofmt
    - misspell
    - lll
    - nakedret
  enable-all: false
  disable:
    - depguard
    - prealloc
    - dupl
    - maligned
  disable-all: false

issues:
  exclude-use-default: false
  max-per-linter: 1024
  max-same: 1024
  exclude:
    - "G304"
    - "G101"
    - "G104"
@@ -0,0 +1,5 @@
Primary contributors:

Gilles FABIO <gilles@ulule.com>
Florent MESSA <florent@ulule.com>
Thomas LE ROUX <thomas@leroux.io>
@@ -0,0 +1,21 @@
The MIT License (MIT)

Copyright (c) 2015-2018 Ulule

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -0,0 +1,7 @@
.PHONY: test lint

test:
	@(scripts/test)

lint:
	@(scripts/lint)
@@ -0,0 +1,255 @@
# Limiter

[![Documentation][godoc-img]][godoc-url]
![License][license-img]
[![Build Status][circle-img]][circle-url]
[![Go Report Card][goreport-img]][goreport-url]

_Dead simple rate limit middleware for Go._

- Simple API
- "Store" approach for backend
- Redis support (but not tied to it)
- Middlewares: HTTP, [FastHTTP][6] and [Gin][4]

## Installation

Using [Go Modules](https://github.com/golang/go/wiki/Modules)

```bash
$ go get github.com/ulule/limiter/v3@v3.10.0
```

## Usage

In five steps:

- Create a `limiter.Rate` instance _(the number of requests per period)_
- Create a `limiter.Store` instance _(see [Redis](https://github.com/ulule/limiter/blob/master/drivers/store/redis/store.go) or [In-Memory](https://github.com/ulule/limiter/blob/master/drivers/store/memory/store.go))_
- Create a `limiter.Limiter` instance that takes store and rate instances as arguments
- Create a middleware instance using the middleware of your choice
- Give the limiter instance to your middleware initializer

**Example:**

```go
// Create a rate with the given limit (number of requests) for the given
// period (a time.Duration of your choice).
import "github.com/ulule/limiter/v3"

rate := limiter.Rate{
	Period: 1 * time.Hour,
	Limit:  1000,
}

// You can also use the simplified format "<limit>-<period>", with the given
// periods:
//
// * "S": second
// * "M": minute
// * "H": hour
// * "D": day
//
// Examples:
//
// * 5 reqs/second: "5-S"
// * 10 reqs/minute: "10-M"
// * 1000 reqs/hour: "1000-H"
// * 2000 reqs/day: "2000-D"
//
rate, err := limiter.NewRateFromFormatted("1000-H")
if err != nil {
	panic(err)
}

// Then, create a store. Here, we use the bundled Redis store. Any store
// compliant with the limiter.Store interface will do the job. The defaults are
// "limiter" as the Redis key prefix and a maximum of 3 retries for the key under
// race conditions.
import "github.com/ulule/limiter/v3/drivers/store/redis"

store, err := redis.NewStore(client)
if err != nil {
	panic(err)
}

// Alternatively, you can pass options to the store with the "WithOptions"
// function. For example, for the Redis store:
import "github.com/ulule/limiter/v3/drivers/store/redis"

store, err := redis.NewStoreWithOptions(pool, limiter.StoreOptions{
	Prefix: "your_own_prefix",
})
if err != nil {
	panic(err)
}

// Or use an in-memory store with a goroutine that clears expired keys.
import "github.com/ulule/limiter/v3/drivers/store/memory"

store := memory.NewStore()

// Then, create the limiter instance, which takes the store and the rate as
// arguments. Now, you can give this instance to any supported middleware.
instance := limiter.New(store, rate)

// Alternatively, you can configure the limiter instance with several options.
instance := limiter.New(store, rate, limiter.WithClientIPHeader("True-Client-IP"), limiter.WithIPv6Mask(mask))

// Finally, give the limiter instance to your middleware initializer.
import "github.com/ulule/limiter/v3/drivers/middleware/stdlib"

middleware := stdlib.NewMiddleware(instance)
```

See middleware examples:

- [HTTP](https://github.com/ulule/limiter-examples/tree/master/http/main.go)
- [Gin](https://github.com/ulule/limiter-examples/tree/master/gin/main.go)
- [Beego](https://github.com/ulule/limiter-examples/blob/master//beego/main.go)
- [Chi](https://github.com/ulule/limiter-examples/tree/master/chi/main.go)
- [Echo](https://github.com/ulule/limiter-examples/tree/master/echo/main.go)
- [Fasthttp](https://github.com/ulule/limiter-examples/tree/master/fasthttp/main.go)

## How it works

The IP address of the request is used as a key in the store.

If the key does not exist in the store, we set a default
value with an expiration period.

You will find two stores:

- Redis: relies on [TTL](http://redis.io/commands/ttl) and incrementing the rate limit on each request.
- In-Memory: relies on a fork of [go-cache](https://github.com/patrickmn/go-cache) with a goroutine to clear expired keys using a default interval.

When the limit is reached, a `429` HTTP status code is sent.
## Limiter behind a reverse proxy

### Introduction

If your limiter is behind a reverse proxy, it can be difficult to obtain the "real" client IP.

Some reverse proxies, such as AWS ALB, pass through any header values that they do not set themselves,
for example `True-Client-IP` and `X-Real-IP`.
Similarly, `X-Forwarded-For` is a list of comma-separated IPs that each traversed proxy appends to.
The idea is that the first IP _(added by the first proxy)_ is the true client IP, and each subsequent IP is another proxy along the path.

An attacker can spoof any of these headers, and the spoofed value could then be reported as the client IP.

By default, limiter does not trust any of these headers: you have to explicitly enable them in order to use them.
If you enable them, **you must always be aware** that any header added by a _(reverse)_ proxy not controlled
by you is **completely unreliable**.

### X-Forwarded-For

For example, if you make this request to your load balancer:

```bash
curl -X POST https://example.com/login -H "X-Forwarded-For: 1.2.3.4, 11.22.33.44"
```

your server behind the load balancer receives this:

```
X-Forwarded-For: 1.2.3.4, 11.22.33.44, <actual client IP>
```

This means you cannot use the `X-Forwarded-For` header, because it is **unreliable** and **untrustworthy**,
so keep `TrustForwardHeader` disabled in your limiter options.

However, if you have configured your reverse proxy to always remove or overwrite the `X-Forwarded-For` and/or `X-Real-IP` headers,
so that the same request:

```bash
curl -X POST https://example.com/login -H "X-Forwarded-For: 1.2.3.4, 11.22.33.44"
```

arrives at your server behind the load balancer as:

```
X-Forwarded-For: <actual client IP>
```

then you can enable `TrustForwardHeader` in your limiter options.
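The "count trusted hops from the right" rule implied above can be made concrete. The sketch below is an illustration under stated assumptions (the function name `clientIP` and the `trustedHops` parameter are invented here): with `trustedHops` proxies under your control, the `trustedHops`-th entry from the right of `X-Forwarded-For` was written by your outermost proxy and records the address that actually connected to it; everything further left is attacker-controllable.

```go
package main

import (
	"fmt"
	"strings"
)

// clientIP picks a client IP from an X-Forwarded-For value, trusting only
// the last `trustedHops` entries (appended by proxies you control). If the
// header cannot be used safely, it falls back to the TCP peer address.
func clientIP(xff string, remoteAddr string, trustedHops int) string {
	if xff == "" {
		return remoteAddr
	}
	parts := strings.Split(xff, ",")
	for i := range parts {
		parts[i] = strings.TrimSpace(parts[i])
	}
	// Entry written by the outermost trusted proxy, counted from the right.
	idx := len(parts) - trustedHops
	if idx < 0 || idx >= len(parts) {
		return remoteAddr
	}
	return parts[idx]
}

func main() {
	// Spoofed entries on the left; the trusted load balancer appended the
	// address that really connected to it (203.0.113.7).
	fmt.Println(clientIP("1.2.3.4, 11.22.33.44, 203.0.113.7", "10.0.0.5", 1))
	// → 203.0.113.7
}
```

With `TrustForwardHeader` disabled, the library effectively behaves like the `remoteAddr` fallback above.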
### Custom header

Many CDNs and cloud providers add a custom header to identify the client IP. For example, this non-exhaustive list:

* `Fastly-Client-IP` from Fastly
* `CF-Connecting-IP` from Cloudflare
* `X-Azure-ClientIP` from Azure

You can use these headers via `ClientIPHeader` in your limiter options.

### None of the above

If none of the above solutions works for you, use a custom `KeyGetter` in your middleware.

This excellent article can help you define the best strategy for your network topology and security needs:
https://adam-p.ca/blog/2022/03/x-forwarded-for/

If you have any ideas or suggestions on how we could simplify these steps, don't hesitate to raise an issue.
We would also like feedback on how we could implement these steps in the Limiter API.

Thank you.
## Why Yet Another Package

You could ask us: why yet another rate-limit package?

Because existing packages did not suit our needs.

We tried a lot of alternatives:

1. [Throttled][1]. This package uses the generic cell-rate algorithm. To cite the
documentation: _"The algorithm has been slightly modified from its usual form to
support limiting with an additional quantity parameter, such as for limiting the
number of bytes uploaded"_. It is brilliant in terms of algorithm, but the
documentation is quite unclear at the moment, we don't need the _burst_ feature for
now, it is impossible to get a correct `Retry-After` (when the limit is exceeded, we
can still make a few requests because of the max burst), and it only supports
`http.Handler` middleware (we use [Gin][4]). Currently, we only need to return `429`
and `X-Ratelimit-*` headers for `n reqs/duration`.

2. [Speedbump][3]. A good package, but maybe too lightweight: no `Reset` support,
only one middleware for the [Gin][4] framework, and too Redis-coupled. We
prefer the "store" approach.

3. [Tollbooth][5]. Also good, but it does both too much and too little. It limits by
remote IP, path, methods, custom headers, and basic-auth usernames, but it does not
provide any Redis support (only _in-memory_) or a ready-to-go middleware that sets
`X-Ratelimit-*` headers. `tollbooth.LimitByRequest(limiter, r)` only returns an HTTP
code.

4. [ratelimit][2]. Probably the closest to our needs but, once again, too
lightweight, no middleware available, and not active (the last commit was in August
2014). Some parts of the code (Redis) come from this project. It deserves much
more love.

There are many other packages on GitHub, but most are either too lightweight, too
old (supporting only old Go versions), or unmaintained. That is why we decided to
create yet another one.

## Contributing

- Ping us on twitter:
  - [@oibafsellig](https://twitter.com/oibafsellig)
  - [@thoas](https://twitter.com/thoas)
  - [@novln\_](https://twitter.com/novln_)
- Fork the [project](https://github.com/ulule/limiter)
- Fix [bugs](https://github.com/ulule/limiter/issues)

Don't hesitate ;)

[1]: https://github.com/throttled/throttled
[2]: https://github.com/r8k/ratelimit
[3]: https://github.com/etcinit/speedbump
[4]: https://github.com/gin-gonic/gin
[5]: https://github.com/didip/tollbooth
[6]: https://github.com/valyala/fasthttp
[godoc-url]: https://pkg.go.dev/github.com/ulule/limiter/v3
[godoc-img]: https://pkg.go.dev/badge/github.com/ulule/limiter/v3
[license-img]: https://img.shields.io/badge/license-MIT-blue.svg
[goreport-url]: https://goreportcard.com/report/github.com/ulule/limiter
[goreport-img]: https://goreportcard.com/badge/github.com/ulule/limiter
[circle-url]: https://circleci.com/gh/ulule/limiter/tree/master
[circle-img]: https://circleci.com/gh/ulule/limiter.svg?style=shield&circle-token=baf62ec320dd871b3a4a7e67fa99530fbc877c99
@ -0,0 +1,15 @@
package limiter

import "time"

const (
	// DefaultPrefix is the default prefix to use for the key in the store.
	DefaultPrefix = "limiter"

	// DefaultMaxRetry is the default maximum number of key retries under
	// race condition (mainly used with database-based stores).
	DefaultMaxRetry = 3

	// DefaultCleanUpInterval is the default time duration for cleanup.
	DefaultCleanUpInterval = 30 * time.Second
)
65 vendor/github.com/ulule/limiter/v3/drivers/middleware/gin/middleware.go generated vendored Normal file

@ -0,0 +1,65 @@
package gin

import (
	"strconv"

	"github.com/gin-gonic/gin"

	"github.com/ulule/limiter/v3"
)

// Middleware is the middleware for gin.
type Middleware struct {
	Limiter        *limiter.Limiter
	OnError        ErrorHandler
	OnLimitReached LimitReachedHandler
	KeyGetter      KeyGetter
	ExcludedKey    func(string) bool
}

// NewMiddleware returns a new instance of a gin middleware.
func NewMiddleware(limiter *limiter.Limiter, options ...Option) gin.HandlerFunc {
	middleware := &Middleware{
		Limiter:        limiter,
		OnError:        DefaultErrorHandler,
		OnLimitReached: DefaultLimitReachedHandler,
		KeyGetter:      DefaultKeyGetter,
		ExcludedKey:    nil,
	}

	for _, option := range options {
		option.apply(middleware)
	}

	return func(ctx *gin.Context) {
		middleware.Handle(ctx)
	}
}

// Handle processes a gin request: it resolves the rate-limit key, writes the
// X-RateLimit-* headers, and aborts the request once the limit is reached.
func (middleware *Middleware) Handle(c *gin.Context) {
	key := middleware.KeyGetter(c)
	if middleware.ExcludedKey != nil && middleware.ExcludedKey(key) {
		c.Next()
		return
	}

	context, err := middleware.Limiter.Get(c, key)
	if err != nil {
		middleware.OnError(c, err)
		c.Abort()
		return
	}

	c.Header("X-RateLimit-Limit", strconv.FormatInt(context.Limit, 10))
	c.Header("X-RateLimit-Remaining", strconv.FormatInt(context.Remaining, 10))
	c.Header("X-RateLimit-Reset", strconv.FormatInt(context.Reset, 10))

	if context.Reached {
		middleware.OnLimitReached(c)
		c.Abort()
		return
	}

	c.Next()
}
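The `Handle` flow above (resolve key, consult the store, write `X-RateLimit-*` headers, abort with `429` once the limit is reached) can be sketched with the standard library alone. This is a minimal fixed-window analogue, not the vendored package's implementation: the `fixedWindow` type, its field names, and the limit of 3 are all illustrative assumptions.

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"strconv"
	"sync"
	"time"
)

// fixedWindow is a toy per-key fixed-window counter (assumption: single
// process, no expiry cleanup) standing in for the limiter store.
type fixedWindow struct {
	mu     sync.Mutex
	limit  int64
	window time.Duration
	counts map[string]int64
	resets map[string]time.Time
}

func (f *fixedWindow) take(key string) (remaining int64, reset time.Time, reached bool) {
	f.mu.Lock()
	defer f.mu.Unlock()
	now := time.Now()
	if r, ok := f.resets[key]; !ok || now.After(r) {
		f.counts[key] = 0
		f.resets[key] = now.Add(f.window)
	}
	f.counts[key]++
	count := f.counts[key]
	remaining = f.limit - count
	if remaining < 0 {
		remaining = 0
	}
	return remaining, f.resets[key], count > f.limit
}

// middleware mirrors the Handle flow: set X-RateLimit-* headers, then
// abort with 429 once the limit is reached.
func middleware(f *fixedWindow, next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		remaining, reset, reached := f.take(r.RemoteAddr)
		w.Header().Set("X-RateLimit-Limit", strconv.FormatInt(f.limit, 10))
		w.Header().Set("X-RateLimit-Remaining", strconv.FormatInt(remaining, 10))
		w.Header().Set("X-RateLimit-Reset", strconv.FormatInt(reset.Unix(), 10))
		if reached {
			http.Error(w, "Limit exceeded", http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// lastStatus sends n requests from one client and returns the last status code.
func lastStatus(n int) int {
	f := &fixedWindow{limit: 3, window: time.Minute,
		counts: map[string]int64{}, resets: map[string]time.Time{}}
	h := middleware(f, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	code := 0
	for i := 0; i < n; i++ {
		rec := httptest.NewRecorder()
		req := httptest.NewRequest("GET", "/", nil)
		req.RemoteAddr = "192.0.2.1:1234"
		h.ServeHTTP(rec, req)
		code = rec.Code
	}
	return code
}

func main() {
	fmt.Println(lastStatus(3), lastStatus(4)) // three requests pass; the fourth is limited
}
```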
71 vendor/github.com/ulule/limiter/v3/drivers/middleware/gin/options.go generated vendored Normal file

@ -0,0 +1,71 @@
package gin

import (
	"net/http"

	"github.com/gin-gonic/gin"
)

// Option is used to define Middleware configuration.
type Option interface {
	apply(*Middleware)
}

type option func(*Middleware)

func (o option) apply(middleware *Middleware) {
	o(middleware)
}

// ErrorHandler is a handler used to inform when an error has occurred.
type ErrorHandler func(c *gin.Context, err error)

// WithErrorHandler will configure the Middleware to use the given ErrorHandler.
func WithErrorHandler(handler ErrorHandler) Option {
	return option(func(middleware *Middleware) {
		middleware.OnError = handler
	})
}

// DefaultErrorHandler is the default ErrorHandler used by a new Middleware.
func DefaultErrorHandler(c *gin.Context, err error) {
	panic(err)
}

// LimitReachedHandler is a handler used to inform when the limit has been exceeded.
type LimitReachedHandler func(c *gin.Context)

// WithLimitReachedHandler will configure the Middleware to use the given LimitReachedHandler.
func WithLimitReachedHandler(handler LimitReachedHandler) Option {
	return option(func(middleware *Middleware) {
		middleware.OnLimitReached = handler
	})
}

// DefaultLimitReachedHandler is the default LimitReachedHandler used by a new Middleware.
func DefaultLimitReachedHandler(c *gin.Context) {
	c.String(http.StatusTooManyRequests, "Limit exceeded")
}

// KeyGetter will define the rate limiter key given the gin Context.
type KeyGetter func(c *gin.Context) string

// WithKeyGetter will configure the Middleware to use the given KeyGetter.
func WithKeyGetter(handler KeyGetter) Option {
	return option(func(middleware *Middleware) {
		middleware.KeyGetter = handler
	})
}

// DefaultKeyGetter is the default KeyGetter used by a new Middleware.
// It returns the client IP address.
func DefaultKeyGetter(c *gin.Context) string {
	return c.ClientIP()
}

// WithExcludedKey will configure the Middleware to ignore key(s) using the given function.
func WithExcludedKey(handler func(string) bool) Option {
	return option(func(middleware *Middleware) {
		middleware.ExcludedKey = handler
	})
}
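The `Option` interface above is the functional-options pattern: a private `option` func type satisfies an unexported `apply` method, so the set of valid options stays closed to the package while callers compose them freely. A stdlib-only sketch of the same wiring (the `config` type, its fields, and the defaults here are illustrative, not from the package):

```go
package main

import "fmt"

// config stands in for the Middleware struct; fields are illustrative.
type config struct {
	limit  int
	prefix string
}

// Option mirrors the interface-based variant above: the unexported apply
// method keeps the option set closed to this package.
type Option interface{ apply(*config) }

type option func(*config)

func (o option) apply(c *config) { o(c) }

// WithLimit overrides the request limit.
func WithLimit(n int) Option {
	return option(func(c *config) { c.limit = n })
}

// WithPrefix overrides the store key prefix.
func WithPrefix(p string) Option {
	return option(func(c *config) { c.prefix = p })
}

// newConfig applies defaults first, then each option in order.
func newConfig(options ...Option) *config {
	c := &config{limit: 1000, prefix: "limiter"}
	for _, o := range options {
		o.apply(c)
	}
	return c
}

func main() {
	c := newConfig(WithLimit(300))
	fmt.Println(c.limit, c.prefix) // 300 limiter
}
```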
28 vendor/github.com/ulule/limiter/v3/drivers/store/common/context.go generated vendored Normal file

@ -0,0 +1,28 @@
package common

import (
	"time"

	"github.com/ulule/limiter/v3"
)

// GetContextFromState generates a new limiter.Context from the given state.
func GetContextFromState(now time.Time, rate limiter.Rate, expiration time.Time, count int64) limiter.Context {
	limit := rate.Limit
	remaining := int64(0)
	reached := true

	if count <= limit {
		remaining = limit - count
		reached = false
	}

	reset := expiration.Unix()

	return limiter.Context{
		Limit:     limit,
		Remaining: remaining,
		Reset:     reset,
		Reached:   reached,
	}
}
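The arithmetic in `GetContextFromState` is small enough to check by hand: `remaining` is `limit - count` while `count <= limit`, otherwise zero with `Reached` set. A stdlib-only sketch of just that computation (the real function also carries the rate and reset timestamp):

```go
package main

import "fmt"

// fromState reproduces the remaining/reached computation above.
func fromState(limit, count int64) (remaining int64, reached bool) {
	remaining = 0
	reached = true
	if count <= limit {
		remaining = limit - count
		reached = false
	}
	return remaining, reached
}

func main() {
	r, ok := fromState(300, 299)
	fmt.Println(r, ok) // 1 false
	r, ok = fromState(300, 301)
	fmt.Println(r, ok) // 0 true
}
```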
@ -0,0 +1,240 @@
package memory

import (
	"runtime"
	"sync"
	"time"
)

// Forked from https://github.com/patrickmn/go-cache

// CacheWrapper is used to ensure that the underlying cleaner goroutine used to clean expired keys will not prevent
// Cache from being garbage collected.
type CacheWrapper struct {
	*Cache
}

// A cleaner will periodically delete expired keys from the cache.
type cleaner struct {
	interval time.Duration
	stop     chan bool
}

// Run will periodically delete expired keys from the given cache until the GC notifies that it should stop.
func (cleaner *cleaner) Run(cache *Cache) {
	ticker := time.NewTicker(cleaner.interval)
	for {
		select {
		case <-ticker.C:
			cache.Clean()
		case <-cleaner.stop:
			ticker.Stop()
			return
		}
	}
}

// stopCleaner is a callback from the GC used to stop the cleaner goroutine.
func stopCleaner(wrapper *CacheWrapper) {
	wrapper.cleaner.stop <- true
	wrapper.cleaner = nil
}

// startCleaner will start a cleaner goroutine for the given cache.
func startCleaner(cache *Cache, interval time.Duration) {
	cleaner := &cleaner{
		interval: interval,
		stop:     make(chan bool),
	}

	cache.cleaner = cleaner
	go cleaner.Run(cache)
}

// Counter is a simple counter with an expiration.
type Counter struct {
	mutex      sync.RWMutex
	value      int64
	expiration int64
}

// Value returns the counter's current value.
func (counter *Counter) Value() int64 {
	counter.mutex.RLock()
	defer counter.mutex.RUnlock()
	return counter.value
}

// Expiration returns the counter's expiration.
func (counter *Counter) Expiration() int64 {
	counter.mutex.RLock()
	defer counter.mutex.RUnlock()
	return counter.expiration
}

// Expired returns true if the counter has expired.
func (counter *Counter) Expired() bool {
	counter.mutex.RLock()
	defer counter.mutex.RUnlock()

	return counter.expiration == 0 || time.Now().UnixNano() > counter.expiration
}

// Load returns the value and the expiration of this counter.
// If the counter is expired, it will use the given expiration.
func (counter *Counter) Load(expiration int64) (int64, int64) {
	counter.mutex.RLock()
	defer counter.mutex.RUnlock()

	if counter.expiration == 0 || time.Now().UnixNano() > counter.expiration {
		return 0, expiration
	}

	return counter.value, counter.expiration
}

// Increment adds the given value to this counter.
// If the counter is expired, it will use the given expiration.
// It returns its current value and expiration.
func (counter *Counter) Increment(value int64, expiration int64) (int64, int64) {
	counter.mutex.Lock()
	defer counter.mutex.Unlock()

	if counter.expiration == 0 || time.Now().UnixNano() > counter.expiration {
		counter.value = value
		counter.expiration = expiration
		return counter.value, counter.expiration
	}

	counter.value += value
	return counter.value, counter.expiration
}

// Cache contains a collection of counters.
type Cache struct {
	counters sync.Map
	cleaner  *cleaner
}

// NewCache returns a new cache.
func NewCache(cleanInterval time.Duration) *CacheWrapper {

	cache := &Cache{}
	wrapper := &CacheWrapper{Cache: cache}

	if cleanInterval > 0 {
		startCleaner(cache, cleanInterval)
		runtime.SetFinalizer(wrapper, stopCleaner)
	}

	return wrapper
}

// LoadOrStore returns the existing counter for the key if present.
// Otherwise, it stores and returns the given counter.
// The loaded result is true if the counter was loaded, false if stored.
func (cache *Cache) LoadOrStore(key string, counter *Counter) (*Counter, bool) {
	val, loaded := cache.counters.LoadOrStore(key, counter)
	if val == nil {
		return counter, false
	}

	actual := val.(*Counter)
	return actual, loaded
}

// Load returns the counter stored in the map for a key, or nil if no counter is present.
// The ok result indicates whether the counter was found in the map.
func (cache *Cache) Load(key string) (*Counter, bool) {
	val, ok := cache.counters.Load(key)
	if val == nil || !ok {
		return nil, false
	}
	actual := val.(*Counter)
	return actual, true
}

// Store sets the counter for a key.
func (cache *Cache) Store(key string, counter *Counter) {
	cache.counters.Store(key, counter)
}

// Delete deletes the value for a key.
func (cache *Cache) Delete(key string) {
	cache.counters.Delete(key)
}

// Range calls handler sequentially for each key and value present in the cache.
func (cache *Cache) Range(handler func(key string, counter *Counter)) {
	cache.counters.Range(func(k interface{}, v interface{}) bool {
		if v == nil {
			return true
		}

		key := k.(string)
		counter := v.(*Counter)

		handler(key, counter)

		return true
	})
}

// Increment adds the given value to the counter stored at key.
// If the key is undefined or expired, it will create it.
func (cache *Cache) Increment(key string, value int64, duration time.Duration) (int64, time.Time) {
	expiration := time.Now().Add(duration).UnixNano()

	// If the counter is in the cache, try to load it first.
	counter, loaded := cache.Load(key)
	if loaded {
		value, expiration = counter.Increment(value, expiration)
		return value, time.Unix(0, expiration)
	}

	// If it's not in the cache, try to atomically create it.
	// We do that in two steps to reduce memory allocation.
	counter, loaded = cache.LoadOrStore(key, &Counter{
		mutex:      sync.RWMutex{},
		value:      value,
		expiration: expiration,
	})
	if loaded {
		value, expiration = counter.Increment(value, expiration)
		return value, time.Unix(0, expiration)
	}

	// Otherwise, it has been created; return the given value.
	return value, time.Unix(0, expiration)
}

// Get returns the key's value and expiration.
func (cache *Cache) Get(key string, duration time.Duration) (int64, time.Time) {
	expiration := time.Now().Add(duration).UnixNano()

	counter, ok := cache.Load(key)
	if !ok {
		return 0, time.Unix(0, expiration)
	}

	value, expiration := counter.Load(expiration)
	return value, time.Unix(0, expiration)
}

// Clean will delete any expired keys.
func (cache *Cache) Clean() {
	cache.Range(func(key string, counter *Counter) {
		if counter.Expired() {
			cache.Delete(key)
		}
	})
}

// Reset deletes the key's counter and resets the expiration.
func (cache *Cache) Reset(key string, duration time.Duration) (int64, time.Time) {
	cache.Delete(key)

	expiration := time.Now().Add(duration).UnixNano()
	return 0, time.Unix(0, expiration)
}
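The two-step Load-then-LoadOrStore shape in `Cache.Increment` is safe under concurrency because `sync.Map.LoadOrStore` is atomic: if two goroutines race on first use of a key, exactly one counter survives and both add to it. A stdlib-only sketch of the same pattern (the `increment`/`total` helpers are illustrative, not the vendored code):

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

var counters sync.Map // key -> *int64

// increment mirrors Cache.Increment's shape: a fast-path Load, then an
// atomic LoadOrStore so only one counter survives a race on first use.
func increment(key string) int64 {
	if v, ok := counters.Load(key); ok {
		return atomic.AddInt64(v.(*int64), 1)
	}
	fresh := int64(1)
	v, loaded := counters.LoadOrStore(key, &fresh)
	if loaded {
		// Another goroutine created the counter first: add to its value.
		return atomic.AddInt64(v.(*int64), 1)
	}
	return 1
}

// total runs n concurrent increments on key and reads back the sum.
func total(key string, n int) int64 {
	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() { defer wg.Done(); increment(key) }()
	}
	wg.Wait()
	v, _ := counters.Load(key)
	return atomic.LoadInt64(v.(*int64))
}

func main() {
	fmt.Println(total("ip:192.0.2.1", 100)) // 100
}
```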
@ -0,0 +1,82 @@
package memory

import (
	"context"
	"time"

	"github.com/ulule/limiter/v3"
	"github.com/ulule/limiter/v3/drivers/store/common"
	"github.com/ulule/limiter/v3/internal/bytebuffer"
)

// Store is the in-memory store.
type Store struct {
	// Prefix used for the key.
	Prefix string
	// cache used to store values in-memory.
	cache *CacheWrapper
}

// NewStore creates a new instance of memory store with defaults.
func NewStore() limiter.Store {
	return NewStoreWithOptions(limiter.StoreOptions{
		Prefix:          limiter.DefaultPrefix,
		CleanUpInterval: limiter.DefaultCleanUpInterval,
	})
}

// NewStoreWithOptions creates a new instance of memory store with options.
func NewStoreWithOptions(options limiter.StoreOptions) limiter.Store {
	return &Store{
		Prefix: options.Prefix,
		cache:  NewCache(options.CleanUpInterval),
	}
}

// Get returns the limit for the given identifier.
func (store *Store) Get(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
	buffer := bytebuffer.New()
	defer buffer.Close()
	buffer.Concat(store.Prefix, ":", key)

	count, expiration := store.cache.Increment(buffer.String(), 1, rate.Period)

	lctx := common.GetContextFromState(time.Now(), rate, expiration, count)
	return lctx, nil
}

// Increment increments the limit by the given count and returns the new limit value for the given identifier.
func (store *Store) Increment(ctx context.Context, key string, count int64, rate limiter.Rate) (limiter.Context, error) {
	buffer := bytebuffer.New()
	defer buffer.Close()
	buffer.Concat(store.Prefix, ":", key)

	newCount, expiration := store.cache.Increment(buffer.String(), count, rate.Period)

	lctx := common.GetContextFromState(time.Now(), rate, expiration, newCount)
	return lctx, nil
}

// Peek returns the limit for the given identifier, without modifying the current values.
func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
	buffer := bytebuffer.New()
	defer buffer.Close()
	buffer.Concat(store.Prefix, ":", key)

	count, expiration := store.cache.Get(buffer.String(), rate.Period)

	lctx := common.GetContextFromState(time.Now(), rate, expiration, count)
	return lctx, nil
}

// Reset resets the limit to zero for the given identifier.
func (store *Store) Reset(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
	buffer := bytebuffer.New()
	defer buffer.Close()
	buffer.Concat(store.Prefix, ":", key)

	count, expiration := store.cache.Reset(buffer.String(), rate.Period)

	lctx := common.GetContextFromState(time.Now(), rate, expiration, count)
	return lctx, nil
}
@ -0,0 +1,58 @@
package bytebuffer

import (
	"sync"
	"unsafe"
)

// ByteBuffer is a wrapper around a slice to reduce memory allocation while handling blobs of data.
type ByteBuffer struct {
	blob []byte
}

// New creates a new ByteBuffer instance.
func New() *ByteBuffer {
	entry := bufferPool.Get().(*ByteBuffer)
	entry.blob = entry.blob[:0]
	return entry
}

// Bytes returns the content of the buffer.
func (buffer *ByteBuffer) Bytes() []byte {
	return buffer.blob
}

// String returns the content of the buffer as a string.
func (buffer *ByteBuffer) String() string {
	// Copied from strings.(*Builder).String
	return *(*string)(unsafe.Pointer(&buffer.blob)) // nolint: gosec
}

// Concat appends the given arguments to the blob content.
func (buffer *ByteBuffer) Concat(args ...string) {
	for i := range args {
		buffer.blob = append(buffer.blob, args[i]...)
	}
}

// Close recycles the underlying resources of the buffer.
func (buffer *ByteBuffer) Close() {
	// Proper usage of a sync.Pool requires each entry to have approximately
	// the same memory cost. To obtain this property when the stored type
	// contains a variably-sized buffer, we add a hard limit on the maximum buffer
	// to place back in the pool.
	//
	// See https://golang.org/issue/23199
	if buffer != nil && cap(buffer.blob) < (1<<16) {
		bufferPool.Put(buffer)
	}
}

// A byte buffer pool to reduce memory allocation pressure.
var bufferPool = &sync.Pool{
	New: func() interface{} {
		return &ByteBuffer{
			blob: make([]byte, 0, 1024),
		}
	},
}
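The bytebuffer package exists so the store can build `prefix:key` strings without allocating on every request. The same pooling idea can be sketched with `sync.Pool` alone, including the capacity guard from golang.org/issue/23199; unlike the vendored code, this sketch copies with `string(b)` instead of an unsafe zero-copy view, and `buildKey` is an illustrative helper, not the package's API:

```go
package main

import (
	"fmt"
	"sync"
)

var pool = sync.Pool{New: func() interface{} {
	b := make([]byte, 0, 1024)
	return &b
}}

// buildKey concatenates parts into a pooled buffer and returns the key.
// Buffers grown past 64 KiB are dropped instead of pooled, so one huge
// key cannot pin memory for the whole pool (see golang.org/issue/23199).
func buildKey(parts ...string) string {
	bp := pool.Get().(*[]byte)
	b := (*bp)[:0]
	for _, p := range parts {
		b = append(b, p...)
	}
	key := string(b) // copy before the buffer is recycled
	if cap(b) < 1<<16 {
		*bp = b
		pool.Put(bp)
	}
	return key
}

func main() {
	fmt.Println(buildKey("limiter", ":", "192.0.2.1")) // limiter:192.0.2.1
}
```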
@ -0,0 +1,65 @@
package limiter

import (
	"context"
)

// -----------------------------------------------------------------
// Context
// -----------------------------------------------------------------

// Context is the limit context.
type Context struct {
	Limit     int64
	Remaining int64
	Reset     int64
	Reached   bool
}

// -----------------------------------------------------------------
// Limiter
// -----------------------------------------------------------------

// Limiter is the limiter instance.
type Limiter struct {
	Store   Store
	Rate    Rate
	Options Options
}

// New returns an instance of Limiter.
func New(store Store, rate Rate, options ...Option) *Limiter {
	opt := Options{
		IPv4Mask:           DefaultIPv4Mask,
		IPv6Mask:           DefaultIPv6Mask,
		TrustForwardHeader: false,
	}
	for _, o := range options {
		o(&opt)
	}
	return &Limiter{
		Store:   store,
		Rate:    rate,
		Options: opt,
	}
}

// Get returns the limit for the given identifier.
func (limiter *Limiter) Get(ctx context.Context, key string) (Context, error) {
	return limiter.Store.Get(ctx, key, limiter.Rate)
}

// Peek returns the limit for the given identifier, without modifying the current values.
func (limiter *Limiter) Peek(ctx context.Context, key string) (Context, error) {
	return limiter.Store.Peek(ctx, key, limiter.Rate)
}

// Reset sets the limit for the given identifier to zero.
func (limiter *Limiter) Reset(ctx context.Context, key string) (Context, error) {
	return limiter.Store.Reset(ctx, key, limiter.Rate)
}

// Increment increments the limit by the given count and returns the new limit for the given identifier.
func (limiter *Limiter) Increment(ctx context.Context, key string, count int64) (Context, error) {
	return limiter.Store.Increment(ctx, key, count, limiter.Rate)
}
@ -0,0 +1,137 @@
package limiter

import (
	"net"
	"net/http"
	"strings"
)

var (
	// DefaultIPv4Mask defines the default IPv4 mask used to obtain the user IP.
	DefaultIPv4Mask = net.CIDRMask(32, 32)
	// DefaultIPv6Mask defines the default IPv6 mask used to obtain the user IP.
	DefaultIPv6Mask = net.CIDRMask(128, 128)
)

// GetIP returns the IP address from the request.
// If options is defined and either TrustForwardHeader is true or ClientIPHeader is defined,
// it will look up the IP in HTTP headers.
// Please be advised that using this option could be insecure (i.e. spoofed) if your reverse
// proxy is not configured properly to forward a trustworthy client IP.
// Please read the section "Limiter behind a reverse proxy" in the README for further information.
func (limiter *Limiter) GetIP(r *http.Request) net.IP {
	return GetIP(r, limiter.Options)
}

// GetIPWithMask returns the IP address from the request after applying a mask.
// If options is defined and either TrustForwardHeader is true or ClientIPHeader is defined,
// it will look up the IP in HTTP headers.
// Please be advised that using this option could be insecure (i.e. spoofed) if your reverse
// proxy is not configured properly to forward a trustworthy client IP.
// Please read the section "Limiter behind a reverse proxy" in the README for further information.
func (limiter *Limiter) GetIPWithMask(r *http.Request) net.IP {
	return GetIPWithMask(r, limiter.Options)
}

// GetIPKey extracts the IP from the request and returns the masked IP to use as a store key.
// If options is defined and either TrustForwardHeader is true or ClientIPHeader is defined,
// it will look up the IP in HTTP headers.
// Please be advised that using this option could be insecure (i.e. spoofed) if your reverse
// proxy is not configured properly to forward a trustworthy client IP.
// Please read the section "Limiter behind a reverse proxy" in the README for further information.
func (limiter *Limiter) GetIPKey(r *http.Request) string {
	return limiter.GetIPWithMask(r).String()
}

// GetIP returns the IP address from the request.
// If options is defined and either TrustForwardHeader is true or ClientIPHeader is defined,
// it will look up the IP in HTTP headers.
// Please be advised that using this option could be insecure (i.e. spoofed) if your reverse
// proxy is not configured properly to forward a trustworthy client IP.
// Please read the section "Limiter behind a reverse proxy" in the README for further information.
func GetIP(r *http.Request, options ...Options) net.IP {
	if len(options) >= 1 {
		if options[0].ClientIPHeader != "" {
			ip := getIPFromHeader(r, options[0].ClientIPHeader)
			if ip != nil {
				return ip
			}
		}
		if options[0].TrustForwardHeader {
			ip := getIPFromXFFHeader(r)
			if ip != nil {
				return ip
			}

			ip = getIPFromHeader(r, "X-Real-IP")
			if ip != nil {
				return ip
			}
		}
	}

	remoteAddr := strings.TrimSpace(r.RemoteAddr)
	host, _, err := net.SplitHostPort(remoteAddr)
	if err != nil {
		return net.ParseIP(remoteAddr)
	}

	return net.ParseIP(host)
}

// GetIPWithMask returns the IP address from the request after applying a mask.
// If options is defined and either TrustForwardHeader is true or ClientIPHeader is defined,
// it will look up the IP in HTTP headers.
// Please be advised that using this option could be insecure (i.e. spoofed) if your reverse
// proxy is not configured properly to forward a trustworthy client IP.
// Please read the section "Limiter behind a reverse proxy" in the README for further information.
func GetIPWithMask(r *http.Request, options ...Options) net.IP {
	if len(options) == 0 {
		return GetIP(r)
	}

	ip := GetIP(r, options[0])
	if ip.To4() != nil {
		return ip.Mask(options[0].IPv4Mask)
	}
	if ip.To16() != nil {
		return ip.Mask(options[0].IPv6Mask)
	}
	return ip
}

func getIPFromXFFHeader(r *http.Request) net.IP {
	headers := r.Header.Values("X-Forwarded-For")
	if len(headers) == 0 {
		return nil
	}

	parts := []string{}
	for _, header := range headers {
		parts = append(parts, strings.Split(header, ",")...)
	}

	for i := range parts {
		part := strings.TrimSpace(parts[i])
		ip := net.ParseIP(part)
		if ip != nil {
			return ip
		}
	}

	return nil
}

func getIPFromHeader(r *http.Request, name string) net.IP {
	header := strings.TrimSpace(r.Header.Get(name))
	if header == "" {
		return nil
	}

	ip := net.ParseIP(header)
	if ip != nil {
		return ip
	}

	return nil
}
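`GetIPWithMask` groups clients by masking their address before it becomes a store key. With this PR's IPv6 CIDR /64 option, every host inside one /64 prefix maps to the same limiter key, while the default /32 IPv4 mask keeps each IPv4 address distinct. A stdlib check of that behavior (the `maskKey` helper and the example addresses are illustrative):

```go
package main

import (
	"fmt"
	"net"
)

// maskKey applies the same CIDR masks the limiter uses: /32 keeps an
// IPv4 address intact, while /64 collapses an IPv6 /64 to a single key.
func maskKey(addr string) string {
	ip := net.ParseIP(addr)
	if ip.To4() != nil {
		return ip.Mask(net.CIDRMask(32, 32)).String()
	}
	return ip.Mask(net.CIDRMask(64, 128)).String()
}

func main() {
	fmt.Println(maskKey("192.0.2.123"))          // 192.0.2.123
	fmt.Println(maskKey("2001:db8:1:2:3:4:5:6")) // 2001:db8:1:2::
	fmt.Println(maskKey("2001:db8:1:2:ffff::1")) // 2001:db8:1:2::
}
```

Two IPv6 hosts in the same /64 therefore share one rate-limit bucket, which stops a single user with a /64 allocation from getting a fresh limit per address.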
@ -0,0 +1,61 @@
package limiter

import (
	"net"
)

// Option is a functional option.
type Option func(*Options)

// Options are limiter options.
type Options struct {
	// IPv4Mask defines the mask used to obtain an IPv4 address.
	IPv4Mask net.IPMask
	// IPv6Mask defines the mask used to obtain an IPv6 address.
	IPv6Mask net.IPMask
	// TrustForwardHeader enables parsing of X-Real-IP and X-Forwarded-For headers to obtain user IP.
	// Please be advised that using this option could be insecure (ie: spoofed) if your reverse
	// proxy is not configured properly to forward a trustworthy client IP.
	// Please read the section "Limiter behind a reverse proxy" in the README for further information.
	TrustForwardHeader bool
	// ClientIPHeader defines a custom header (likely defined by your CDN or Cloud provider) to obtain user IP.
	// If configured, this option will override the "TrustForwardHeader" option.
	// Please be advised that using this option could be insecure (ie: spoofed) if your reverse
	// proxy is not configured properly to forward a trustworthy client IP.
	// Please read the section "Limiter behind a reverse proxy" in the README for further information.
	ClientIPHeader string
}

// WithIPv4Mask will configure the limiter to use the given mask for IPv4 addresses.
func WithIPv4Mask(mask net.IPMask) Option {
	return func(o *Options) {
		o.IPv4Mask = mask
	}
}

// WithIPv6Mask will configure the limiter to use the given mask for IPv6 addresses.
func WithIPv6Mask(mask net.IPMask) Option {
	return func(o *Options) {
		o.IPv6Mask = mask
	}
}

// WithTrustForwardHeader will configure the limiter to trust X-Real-IP and X-Forwarded-For headers.
// Please be advised that using this option could be insecure (ie: spoofed) if your reverse
// proxy is not configured properly to forward a trustworthy client IP.
// Please read the section "Limiter behind a reverse proxy" in the README for further information.
func WithTrustForwardHeader(enable bool) Option {
	return func(o *Options) {
		o.TrustForwardHeader = enable
	}
}

// WithClientIPHeader will configure the limiter to use a custom header to obtain user IP.
// Please be advised that using this option could be insecure (ie: spoofed) if your reverse
// proxy is not configured properly to forward a trustworthy client IP.
// Please read the section "Limiter behind a reverse proxy" in the README for further information.
func WithClientIPHeader(header string) Option {
	return func(o *Options) {
		o.ClientIPHeader = header
	}
}
@ -0,0 +1,53 @@
package limiter

import (
	"strconv"
	"strings"
	"time"

	"github.com/pkg/errors"
)

// Rate is the rate.
type Rate struct {
	Formatted string
	Period    time.Duration
	Limit     int64
}

// NewRateFromFormatted returns the rate from the formatted version.
func NewRateFromFormatted(formatted string) (Rate, error) {
	rate := Rate{}

	values := strings.Split(formatted, "-")
	if len(values) != 2 {
		return rate, errors.Errorf("incorrect format '%s'", formatted)
	}

	periods := map[string]time.Duration{
		"S": time.Second,    // Second
		"M": time.Minute,    // Minute
		"H": time.Hour,      // Hour
		"D": time.Hour * 24, // Day
	}

	limit, period := values[0], strings.ToUpper(values[1])

	p, ok := periods[period]
	if !ok {
		return rate, errors.Errorf("incorrect period '%s'", period)
	}

	l, err := strconv.ParseInt(limit, 10, 64)
	if err != nil {
		return rate, errors.Errorf("incorrect limit '%s'", limit)
	}

	rate = Rate{
		Formatted: formatted,
		Period:    p,
		Limit:     l,
	}

	return rate, nil
}
@ -0,0 +1,34 @@
package limiter

import (
	"context"
	"time"
)

// Store is the common interface for limiter stores.
type Store interface {
	// Get returns the limit for given identifier.
	Get(ctx context.Context, key string, rate Rate) (Context, error)
	// Peek returns the limit for given identifier, without modification on current values.
	Peek(ctx context.Context, key string, rate Rate) (Context, error)
	// Reset resets the limit to zero for given identifier.
	Reset(ctx context.Context, key string, rate Rate) (Context, error)
	// Increment increments the limit by given count and gives back the new limit for given identifier.
	Increment(ctx context.Context, key string, count int64, rate Rate) (Context, error)
}

// StoreOptions are options for store.
type StoreOptions struct {
	// Prefix is the prefix to use for the key.
	Prefix string

	// MaxRetry is the maximum number of retries under race conditions on the redis store.
	// Deprecated: this option is no longer required since all operations are atomic now.
	MaxRetry int

	// CleanUpInterval is the interval for cleanup (run garbage collection) on stale entries on the memory store.
	// Setting this to a low value will optimize memory consumption, but will likely
	// reduce performance and increase lock contention.
	// Setting this to a high value will maximize throughput, but will increase the memory footprint.
	CleanUpInterval time.Duration
}
@ -230,7 +230,7 @@ github.com/json-iterator/go
# github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51
## explicit
github.com/kballard/go-shellquote
-# github.com/klauspost/compress v1.13.6
+# github.com/klauspost/compress v1.15.0
## explicit; go 1.15
github.com/klauspost/compress/s2
# github.com/klauspost/cpuid v1.3.1
@ -295,6 +295,9 @@ github.com/pelletier/go-toml/v2
github.com/pelletier/go-toml/v2/internal/ast
github.com/pelletier/go-toml/v2/internal/danger
github.com/pelletier/go-toml/v2/internal/tracker
+# github.com/pkg/errors v0.9.1
+## explicit
+github.com/pkg/errors
# github.com/pmezard/go-difflib v1.0.0
## explicit
github.com/pmezard/go-difflib/difflib
@ -563,6 +566,13 @@ github.com/tmthrgd/go-hex
# github.com/ugorji/go/codec v1.2.7
## explicit; go 1.11
github.com/ugorji/go/codec
+# github.com/ulule/limiter/v3 v3.10.0
+## explicit; go 1.17
+github.com/ulule/limiter/v3
+github.com/ulule/limiter/v3/drivers/middleware/gin
+github.com/ulule/limiter/v3/drivers/store/common
+github.com/ulule/limiter/v3/drivers/store/memory
+github.com/ulule/limiter/v3/internal/bytebuffer
# github.com/uptrace/bun v1.1.7
## explicit; go 1.17
github.com/uptrace/bun