accounting: factor --tpslimit code into accounting from fshttp

This commit is contained in:
Nick Craig-Wood 2021-01-06 12:06:00 +00:00
parent d58fdb10db
commit 50344e7792
4 changed files with 80 additions and 24 deletions

37
fs/accounting/tpslimit.go Normal file
View File

@ -0,0 +1,37 @@
package accounting
import (
"context"
"github.com/rclone/rclone/fs"
"golang.org/x/time/rate"
)
var (
tpsBucket *rate.Limiter // for limiting number of http transactions per second
)
// StartLimitTPS starts the token bucket for transactions per second
// limiting if necessary
func StartLimitTPS(ctx context.Context) {
ci := fs.GetConfig(ctx)
if ci.TPSLimit > 0 {
tpsBurst := ci.TPSLimitBurst
if tpsBurst < 1 {
tpsBurst = 1
}
tpsBucket = rate.NewLimiter(rate.Limit(ci.TPSLimit), tpsBurst)
fs.Infof(nil, "Starting transaction limiter: max %g transactions/s with burst %d", ci.TPSLimit, tpsBurst)
}
}
// LimitTPS limits the number of transactions per second if enabled.
// It should be called once per transaction.
func LimitTPS(ctx context.Context) {
if tpsBucket != nil {
tbErr := tpsBucket.Wait(ctx)
if tbErr != nil && tbErr != context.Canceled {
fs.Errorf(nil, "HTTP token bucket error: %v", tbErr)
}
}
}

View File

@ -0,0 +1,39 @@
package accounting
import (
"context"
"testing"
"time"
"github.com/rclone/rclone/fs"
"github.com/stretchr/testify/assert"
)
func TestLimitTPS(t *testing.T) {
timeTransactions := func(n int, minTime, maxTime time.Duration) {
start := time.Now()
for i := 0; i < n; i++ {
LimitTPS(context.Background())
}
dt := time.Since(start)
assert.True(t, dt >= minTime && dt <= maxTime, "Expecting time between %v and %v, got %v", minTime, maxTime, dt)
}
t.Run("Off", func(t *testing.T) {
assert.Nil(t, tpsBucket)
timeTransactions(100, 0*time.Millisecond, 100*time.Millisecond)
})
t.Run("On", func(t *testing.T) {
ctx, ci := fs.AddConfig(context.Background())
ci.TPSLimit = 100.0
ci.TPSLimitBurst = 0
StartLimitTPS(ctx)
assert.NotNil(t, tpsBucket)
defer func() {
tpsBucket = nil
}()
timeTransactions(100, 900*time.Millisecond, 2000*time.Millisecond)
})
}

View File

@ -34,7 +34,6 @@ import (
"github.com/rclone/rclone/fs/config/configstruct"
"github.com/rclone/rclone/fs/config/obscure"
"github.com/rclone/rclone/fs/driveletter"
"github.com/rclone/rclone/fs/fshttp"
"github.com/rclone/rclone/fs/fspath"
"github.com/rclone/rclone/fs/rc"
"github.com/rclone/rclone/lib/random"
@ -236,7 +235,7 @@ func LoadConfig(ctx context.Context) {
accounting.StartTokenTicker(ctx)
// Start the transactions per second limiter
fshttp.StartHTTPTokenBucket(ctx)
accounting.StartLimitTPS(ctx)
}
var errorConfigFileNotFound = errors.New("config file not found")

View File

@ -16,9 +16,9 @@ import (
"time"
"github.com/rclone/rclone/fs"
"github.com/rclone/rclone/fs/accounting"
"github.com/rclone/rclone/lib/structs"
"golang.org/x/net/publicsuffix"
"golang.org/x/time/rate"
)
const (
@ -29,24 +29,10 @@ const (
var (
transport http.RoundTripper
noTransport = new(sync.Once)
tpsBucket *rate.Limiter // for limiting number of http transactions per second
cookieJar, _ = cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
logMutex sync.Mutex
)
// StartHTTPTokenBucket starts the token bucket if necessary
func StartHTTPTokenBucket(ctx context.Context) {
ci := fs.GetConfig(ctx)
if ci.TPSLimit > 0 {
tpsBurst := ci.TPSLimitBurst
if tpsBurst < 1 {
tpsBurst = 1
}
tpsBucket = rate.NewLimiter(rate.Limit(ci.TPSLimit), tpsBurst)
fs.Infof(nil, "Starting HTTP transaction limiter: max %g transactions/s with burst %d", ci.TPSLimit, tpsBurst)
}
}
// A net.Conn that sets a deadline for every Read or Write operation
type timeoutConn struct {
net.Conn
@ -309,13 +295,8 @@ func cleanAuths(buf []byte) []byte {
// RoundTrip implements the RoundTripper interface.
func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
// Get transactions per second token first if limiting
if tpsBucket != nil {
tbErr := tpsBucket.Wait(req.Context())
if tbErr != nil && tbErr != context.Canceled {
fs.Errorf(nil, "HTTP token bucket error: %v", tbErr)
}
}
// Limit transactions per second if required
accounting.LimitTPS(req.Context())
// Force user agent
req.Header.Set("User-Agent", t.userAgent)
// Set user defined headers