Redo http Transport code

* Insert User-Agent in Transport - fixes #199
  * Update timeouts to use Context
  * Modernise transport
This commit is contained in:
Nick Craig-Wood 2016-09-10 19:17:43 +01:00
parent 5c91623148
commit 0cb9bb3b54
15 changed files with 273 additions and 133 deletions

View File

@ -172,7 +172,6 @@ func NewFs(name, root string) (fs.Fs, error) {
}
c := acd.NewClient(oAuthClient)
c.UserAgent = fs.UserAgent
f := &Fs{
name: name,
root: root,

View File

@ -932,7 +932,6 @@ func (o *Object) httpResponse(method string) (res *http.Response, err error) {
if err != nil {
return nil, err
}
req.Header.Set("User-Agent", fs.UserAgent)
err = o.fs.pacer.Call(func() (bool, error) {
res, err = o.fs.client.Do(req)
return shouldRetry(err)

View File

@ -84,7 +84,6 @@ func (f *Fs) Upload(in io.Reader, size int64, contentType string, info *drive.Fi
req.Header.Set("Content-Type", "application/json; charset=UTF-8")
req.Header.Set("X-Upload-Content-Type", contentType)
req.Header.Set("X-Upload-Content-Length", fmt.Sprintf("%v", size))
req.Header.Set("User-Agent", fs.UserAgent)
res, err = f.client.Do(req)
if err == nil {
defer googleapi.CloseBody(res)
@ -118,7 +117,6 @@ func (rx *resumableUpload) makeRequest(start int64, body []byte) *http.Request {
req.Header.Set("Content-Range", fmt.Sprintf("bytes */%v", rx.ContentLength))
}
req.Header.Set("Content-Type", rx.MediaType)
req.Header.Set("User-Agent", fs.UserAgent)
return req
}

View File

@ -9,14 +9,12 @@ import (
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"log"
"math"
"net/http"
"os"
"os/user"
"path"
@ -27,7 +25,6 @@ import (
"unicode/utf8"
"github.com/Unknwon/goconfig"
"github.com/mreiferson/go-httpclient"
"github.com/pkg/errors"
"github.com/spf13/pflag"
"golang.org/x/crypto/nacl/secretbox"
@ -304,62 +301,6 @@ type ConfigInfo struct {
NoUpdateModTime bool
}
// Transport returns an http.RoundTripper with the correct timeouts
func (ci *ConfigInfo) Transport() http.RoundTripper {
t := &httpclient.Transport{
Proxy: http.ProxyFromEnvironment,
MaxIdleConnsPerHost: ci.Checkers + ci.Transfers + 1,
// ConnectTimeout, if non-zero, is the maximum amount of time a dial will wait for
// a connect to complete.
ConnectTimeout: ci.ConnectTimeout,
// ResponseHeaderTimeout, if non-zero, specifies the amount of
// time to wait for a server's response headers after fully
// writing the request (including its body, if any). This
// time does not include the time to read the response body.
ResponseHeaderTimeout: ci.Timeout,
// RequestTimeout, if non-zero, specifies the amount of time for the entire
// request to complete (including all of the above timeouts + entire response body).
// This should never be less than the sum total of the above two timeouts.
//RequestTimeout: NOT SET,
// ReadWriteTimeout, if non-zero, will set a deadline for every Read and
// Write operation on the request connection.
ReadWriteTimeout: ci.Timeout,
// InsecureSkipVerify controls whether a client verifies the
// server's certificate chain and host name.
// If InsecureSkipVerify is true, TLS accepts any certificate
// presented by the server and any host name in that certificate.
// In this mode, TLS is susceptible to man-in-the-middle attacks.
// This should be used only for testing.
TLSClientConfig: &tls.Config{InsecureSkipVerify: ci.InsecureSkipVerify},
// DisableCompression, if true, prevents the Transport from
// requesting compression with an "Accept-Encoding: gzip"
// request header when the Request contains no existing
// Accept-Encoding value. If the Transport requests gzip on
// its own and gets a gzipped response, it's transparently
// decoded in the Response.Body. However, if the user
// explicitly requested gzip it is not automatically
// uncompressed.
DisableCompression: *noGzip,
}
if ci.DumpHeaders || ci.DumpBodies {
return NewLoggedTransport(t, ci.DumpBodies)
}
return t
}
// Client returns an http.Client with the correct timeouts
func (ci *ConfigInfo) Client() *http.Client {
return &http.Client{
Transport: ci.Transport(),
}
}
// Find the config directory
func configHome() string {
// Find users home directory

View File

@ -26,7 +26,7 @@ const (
// Globals
var (
// UserAgent for Fs which can set it
// UserAgent set in the default Transport
UserAgent = "rclone/" + Version
// Filesystem registry
fsRegistry []*RegInfo

169
fs/http.go Normal file
View File

@ -0,0 +1,169 @@
// The HTTP based parts of the config, Transport and Client
package fs
import (
"crypto/tls"
"net"
"net/http"
"net/http/httputil"
"reflect"
"sync"
"time"
)
const (
separatorReq = ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
separatorResp = "<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<"
)
var (
transport http.RoundTripper
noTransport sync.Once
)
// A net.Conn that sets a deadline for every Read or Write operation
type timeoutConn struct {
net.Conn
readTimer *time.Timer
writeTimer *time.Timer
timeout time.Duration
_cancel func()
off time.Time
}
// create a timeoutConn using the timeout
func newTimeoutConn(conn net.Conn, timeout time.Duration) *timeoutConn {
return &timeoutConn{
Conn: conn,
timeout: timeout,
}
}
// Read bytes doing timeouts
func (c *timeoutConn) Read(b []byte) (n int, err error) {
err = c.Conn.SetReadDeadline(time.Now().Add(c.timeout))
if err != nil {
return n, err
}
n, err = c.Conn.Read(b)
cerr := c.Conn.SetReadDeadline(c.off)
if cerr != nil {
err = cerr
}
return n, err
}
// Write bytes doing timeouts
func (c *timeoutConn) Write(b []byte) (n int, err error) {
err = c.Conn.SetWriteDeadline(time.Now().Add(c.timeout))
if err != nil {
return n, err
}
n, err = c.Conn.Write(b)
cerr := c.Conn.SetWriteDeadline(c.off)
if cerr != nil {
err = cerr
}
return n, err
}
// setDefaults for a from b
//
// Copy the public members from b to a. We can't just use a struct
// copy as Transport contains a private mutex.
func setDefaults(a, b interface{}) {
pt := reflect.TypeOf(a)
t := pt.Elem()
va := reflect.ValueOf(a).Elem()
vb := reflect.ValueOf(b).Elem()
for i := 0; i < t.NumField(); i++ {
aField := va.Field(i)
// Set a from b if it is public
if aField.CanSet() {
bField := vb.Field(i)
aField.Set(bField)
}
}
}
// Transport returns an http.RoundTripper with the correct timeouts
func (ci *ConfigInfo) Transport() http.RoundTripper {
noTransport.Do(func() {
// Start with a sensible set of defaults then override.
// This also means we get new stuff when it gets added to go
t := new(http.Transport)
setDefaults(t, http.DefaultTransport.(*http.Transport))
t.Proxy = http.ProxyFromEnvironment
t.MaxIdleConnsPerHost = 4 * (ci.Checkers + ci.Transfers + 1)
t.TLSHandshakeTimeout = ci.ConnectTimeout
t.ResponseHeaderTimeout = ci.ConnectTimeout
t.TLSClientConfig = &tls.Config{InsecureSkipVerify: ci.InsecureSkipVerify}
t.DisableCompression = *noGzip
// Set in http_old.go initTransport
// t.Dial
// Set in http_new.go initTransport
// t.DialContext
// t.IdelConnTimeout
// t.ExpectContinueTimeout
ci.initTransport(t)
// Wrap that http.Transport in our own transport
transport = NewTransport(t, ci.DumpHeaders, ci.DumpBodies)
})
return transport
}
// Client returns an http.Client with the correct timeouts
func (ci *ConfigInfo) Client() *http.Client {
return &http.Client{
Transport: ci.Transport(),
}
}
// Transport is a our http Transport which wraps an http.Transport
// * Sets the User Agent
// * Does logging
type Transport struct {
*http.Transport
logHeader bool
logBody bool
}
// NewTransport wraps the http.Transport passed in and logs all
// roundtrips including the body if logBody is set.
func NewTransport(transport *http.Transport, logHeader, logBody bool) *Transport {
return &Transport{
Transport: transport,
logHeader: logHeader,
logBody: logBody,
}
}
// RoundTrip implements the RoundTripper interface.
func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
// Force user agent
req.Header.Set("User-Agent", UserAgent)
// Log request
if t.logHeader {
buf, _ := httputil.DumpRequestOut(req, t.logBody)
Debug(nil, "%s", separatorReq)
Debug(nil, "%s", "HTTP REQUEST")
Debug(nil, "%s", string(buf))
Debug(nil, "%s", separatorReq)
}
// Do round trip
resp, err = t.Transport.RoundTrip(req)
// Log response
if t.logHeader {
Debug(nil, "%s", separatorResp)
Debug(nil, "%s", "HTTP RESPONSE")
if err != nil {
Debug(nil, "Error: %v", err)
} else {
buf, _ := httputil.DumpResponse(resp, t.logBody)
Debug(nil, "%s", string(buf))
}
Debug(nil, "%s", separatorResp)
}
return resp, err
}

34
fs/http_new.go Normal file
View File

@ -0,0 +1,34 @@
// HTTP parts go1.7+
//+build go1.7
package fs
import (
"context"
"net"
"net/http"
"time"
)
// dial with context and timeouts
func dialContextTimeout(ctx context.Context, network, address string, connectTimeout, timeout time.Duration) (net.Conn, error) {
dialer := net.Dialer{
Timeout: connectTimeout,
KeepAlive: 30 * time.Second,
}
c, err := dialer.DialContext(ctx, network, address)
if err != nil {
return c, err
}
return newTimeoutConn(c, timeout), nil
}
// Initialise the http.Transport for go1.7+
func (ci *ConfigInfo) initTransport(t *http.Transport) {
t.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
return dialContextTimeout(ctx, network, address, ci.ConnectTimeout, ci.Timeout)
}
t.IdleConnTimeout = 60 * time.Second
t.ExpectContinueTimeout = ci.ConnectTimeout
}

31
fs/http_old.go Normal file
View File

@ -0,0 +1,31 @@
// HTTP parts pre go1.7
//+build !go1.7
package fs
import (
"net"
"net/http"
"time"
)
// dial with timeouts
func dialTimeout(network, address string, connectTimeout, timeout time.Duration) (net.Conn, error) {
dialer := net.Dialer{
Timeout: connectTimeout,
KeepAlive: 30 * time.Second,
}
c, err := dialer.Dial(network, address)
if err != nil {
return c, err
}
return newTimeoutConn(c, timeout), nil
}
// Initialise the http.Transport for pre go1.7
func (ci *ConfigInfo) initTransport(t *http.Transport) {
t.Dial = func(network, address string) (net.Conn, error) {
return dialTimeout(network, address, ci.ConnectTimeout, ci.Timeout)
}
}

38
fs/http_test.go Normal file
View File

@ -0,0 +1,38 @@
package fs
import (
"fmt"
"net/http"
"testing"
"github.com/stretchr/testify/assert"
)
// returns the "%p" reprentation of the thing passed in
func ptr(p interface{}) string {
return fmt.Sprintf("%p", p)
}
func TestSetDefaults(t *testing.T) {
old := http.DefaultTransport.(*http.Transport)
new := new(http.Transport)
setDefaults(new, old)
// Can't use assert.Equal or reflect.DeepEqual for this as it has functions in
// Check functions by comparing the "%p" representations of them
assert.Equal(t, ptr(old.Proxy), ptr(new.Proxy), "when checking .Proxy")
assert.Equal(t, ptr(old.DialContext), ptr(new.DialContext), "when checking .DialContext")
// Check the other public fields
assert.Equal(t, old.Dial, new.Dial, "when checking .Dial")
assert.Equal(t, old.DialTLS, new.DialTLS, "when checking .DialTLS")
assert.Equal(t, old.TLSClientConfig, new.TLSClientConfig, "when checking .TLSClientConfig")
assert.Equal(t, old.TLSHandshakeTimeout, new.TLSHandshakeTimeout, "when checking .TLSHandshakeTimeout")
assert.Equal(t, old.DisableKeepAlives, new.DisableKeepAlives, "when checking .DisableKeepAlives")
assert.Equal(t, old.DisableCompression, new.DisableCompression, "when checking .DisableCompression")
assert.Equal(t, old.MaxIdleConns, new.MaxIdleConns, "when checking .MaxIdleConns")
assert.Equal(t, old.MaxIdleConnsPerHost, new.MaxIdleConnsPerHost, "when checking .MaxIdleConnsPerHost")
assert.Equal(t, old.IdleConnTimeout, new.IdleConnTimeout, "when checking .IdleConnTimeout")
assert.Equal(t, old.ResponseHeaderTimeout, new.ResponseHeaderTimeout, "when checking .ResponseHeaderTimeout")
assert.Equal(t, old.ExpectContinueTimeout, new.ExpectContinueTimeout, "when checking .ExpectContinueTimeout")
assert.Equal(t, old.TLSNextProto, new.TLSNextProto, "when checking .TLSNextProto")
assert.Equal(t, old.MaxResponseHeaderBytes, new.MaxResponseHeaderBytes, "when checking .MaxResponseHeaderBytes")
}

View File

@ -1,60 +0,0 @@
// A logging http transport
package fs
import (
"net/http"
"net/http/httputil"
)
const (
separatorReq = ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>"
separatorResp = "<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<"
)
// LoggedTransport is an http transport which logs the traffic
type LoggedTransport struct {
wrapped http.RoundTripper
logBody bool
}
// NewLoggedTransport wraps the transport passed in and logs all roundtrips
// including the body if logBody is set.
func NewLoggedTransport(transport http.RoundTripper, logBody bool) *LoggedTransport {
return &LoggedTransport{
wrapped: transport,
logBody: logBody,
}
}
// CancelRequest cancels an in-flight request by closing its
// connection. CancelRequest should only be called after RoundTrip has
// returned.
func (t *LoggedTransport) CancelRequest(req *http.Request) {
if wrapped, ok := t.wrapped.(interface {
CancelRequest(*http.Request)
}); ok {
Debug(nil, "CANCEL REQUEST %v", req)
wrapped.CancelRequest(req)
}
}
// RoundTrip implements the RoundTripper interface.
func (t *LoggedTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
buf, _ := httputil.DumpRequestOut(req, t.logBody)
Debug(nil, "%s", separatorReq)
Debug(nil, "%s", "HTTP REQUEST")
Debug(nil, "%s", string(buf))
Debug(nil, "%s", separatorReq)
resp, err = t.wrapped.RoundTrip(req)
Debug(nil, "%s", separatorResp)
Debug(nil, "%s", "HTTP RESPONSE")
if err != nil {
Debug(nil, "Error: %v", err)
} else {
buf, _ = httputil.DumpResponse(resp, t.logBody)
Debug(nil, "%s", string(buf))
}
Debug(nil, "%s", separatorResp)
return resp, err
}

View File

@ -654,7 +654,6 @@ func (o *Object) Open() (in io.ReadCloser, err error) {
if err != nil {
return nil, err
}
req.Header.Set("User-Agent", fs.UserAgent)
res, err := o.fs.client.Do(req)
if err != nil {
return nil, err

View File

@ -112,7 +112,6 @@ func (f *Fs) getCredentials() (err error) {
if err != nil {
return err
}
req.Header.Add("User-Agent", fs.UserAgent)
resp, err := f.client.Do(req)
if err != nil {
return err
@ -155,7 +154,6 @@ func NewFs(name, root string) (fs.Fs, error) {
// Make the swift Connection
c := &swiftLib.Connection{
Auth: newAuth(f),
UserAgent: fs.UserAgent,
ConnectTimeout: 10 * fs.Config.ConnectTimeout, // Use the timeouts in the transport
Timeout: 10 * fs.Config.Timeout, // Use the timeouts in the transport
Transport: fs.Config.Transport(),

View File

@ -31,7 +31,6 @@ func NewClient(c *http.Client) *Client {
errorHandler: defaultErrorHandler,
headers: make(map[string]string),
}
api.SetHeader("User-Agent", fs.UserAgent)
return api
}

View File

@ -324,10 +324,6 @@ func s3Connection(name string) (*s3.S3, *session.Session, error) {
c.Handlers.Sign.PushBackNamed(corehandlers.BuildContentLengthHandler)
c.Handlers.Sign.PushBack(signer)
}
// Add user agent
c.Handlers.Build.PushBack(func(r *request.Request) {
r.HTTPRequest.Header.Set("User-Agent", fs.UserAgent)
})
return c, ses, nil
}

View File

@ -162,7 +162,6 @@ func swiftConnection(name string) (*swift.Connection, error) {
ApiKey: apiKey,
AuthUrl: authURL,
AuthVersion: fs.ConfigFile.MustInt(name, "auth_version", 0),
UserAgent: fs.UserAgent,
Tenant: fs.ConfigFile.MustValue(name, "tenant"),
Region: fs.ConfigFile.MustValue(name, "region"),
Domain: fs.ConfigFile.MustValue(name, "domain"),