diff --git a/fs/fshttp/http.go b/fs/fshttp/http.go index a4e0ee54d..d384cee8f 100644 --- a/fs/fshttp/http.go +++ b/fs/fshttp/http.go @@ -12,11 +12,11 @@ import ( "net/http" "net/http/cookiejar" "net/http/httputil" - "reflect" "sync" "time" "github.com/rclone/rclone/fs" + "github.com/rclone/rclone/lib/structs" "golang.org/x/net/publicsuffix" "golang.org/x/time/rate" ) @@ -92,25 +92,6 @@ func (c *timeoutConn) Write(b []byte) (n int, err error) { return c.readOrWrite(c.Conn.Write, b) } -// setDefaults for a from b -// -// Copy the public members from b to a. We can't just use a struct -// copy as Transport contains a private mutex. -func setDefaults(a, b interface{}) { - pt := reflect.TypeOf(a) - t := pt.Elem() - va := reflect.ValueOf(a).Elem() - vb := reflect.ValueOf(b).Elem() - for i := 0; i < t.NumField(); i++ { - aField := va.Field(i) - // Set a from b if it is public - if aField.CanSet() { - bField := vb.Field(i) - aField.Set(bField) - } - } -} - // dial with context and timeouts func dialContextTimeout(ctx context.Context, network, address string, ci *fs.ConfigInfo) (net.Conn, error) { dialer := NewDialer(ci) @@ -134,7 +115,7 @@ func NewTransportCustom(ci *fs.ConfigInfo, customize func(*http.Transport)) http // Start with a sensible set of defaults then override. // This also means we get new stuff when it gets added to go t := new(http.Transport) - setDefaults(t, http.DefaultTransport.(*http.Transport)) + structs.SetDefaults(t, http.DefaultTransport.(*http.Transport)) t.Proxy = http.ProxyFromEnvironment t.MaxIdleConnsPerHost = 2 * (ci.Checkers + ci.Transfers + 1) t.MaxIdleConns = 2 * t.MaxIdleConnsPerHost diff --git a/fs/fshttp/http_test.go b/fs/fshttp/http_test.go index b16747dcd..24f440b43 100644 --- a/fs/fshttp/http_test.go +++ b/fs/fshttp/http_test.go @@ -1,42 +1,11 @@ package fshttp import ( - "fmt" - "net/http" "testing" "github.com/stretchr/testify/assert" ) -// returns the "%p" representation of the thing passed in -func ptr(p interface{}) string { - return fmt.Sprintf("%p", p) -} - -func TestSetDefaults(t *testing.T) { - old := http.DefaultTransport.(*http.Transport) - newT := new(http.Transport) - setDefaults(newT, old) - // Can't use assert.Equal or reflect.DeepEqual for this as it has functions in - // Check functions by comparing the "%p" representations of them - assert.Equal(t, ptr(old.Proxy), ptr(newT.Proxy), "when checking .Proxy") - assert.Equal(t, ptr(old.DialContext), ptr(newT.DialContext), "when checking .DialContext") - // Check the other public fields - assert.Equal(t, ptr(old.Dial), ptr(newT.Dial), "when checking .Dial") - assert.Equal(t, ptr(old.DialTLS), ptr(newT.DialTLS), "when checking .DialTLS") - assert.Equal(t, old.TLSClientConfig, newT.TLSClientConfig, "when checking .TLSClientConfig") - assert.Equal(t, old.TLSHandshakeTimeout, newT.TLSHandshakeTimeout, "when checking .TLSHandshakeTimeout") - assert.Equal(t, old.DisableKeepAlives, newT.DisableKeepAlives, "when checking .DisableKeepAlives") - assert.Equal(t, old.DisableCompression, newT.DisableCompression, "when checking .DisableCompression") - assert.Equal(t, old.MaxIdleConns, newT.MaxIdleConns, "when checking .MaxIdleConns") - assert.Equal(t, old.MaxIdleConnsPerHost, newT.MaxIdleConnsPerHost, "when checking .MaxIdleConnsPerHost") - assert.Equal(t, old.IdleConnTimeout, newT.IdleConnTimeout, "when checking .IdleConnTimeout") - assert.Equal(t, old.ResponseHeaderTimeout, newT.ResponseHeaderTimeout, "when checking .ResponseHeaderTimeout") - assert.Equal(t, old.ExpectContinueTimeout, newT.ExpectContinueTimeout, "when checking .ExpectContinueTimeout") - assert.Equal(t, old.TLSNextProto, newT.TLSNextProto, "when checking .TLSNextProto") - assert.Equal(t, old.MaxResponseHeaderBytes, newT.MaxResponseHeaderBytes, "when checking .MaxResponseHeaderBytes") -} - func TestCleanAuth(t *testing.T) { for _, test := range []struct { in string diff --git a/lib/structs/structs.go b/lib/structs/structs.go new file mode 100644 index 000000000..b1a781bd8 --- /dev/null +++ b/lib/structs/structs.go @@ -0,0 +1,57 @@ +// Package structs is for manipulating structures with reflection +package structs + +import ( + "reflect" +) + +// SetFrom sets the public members of a from b +// +// a and b should be pointers to structs +// +// a can be a different type from b +// +// Only the Fields which have the same name and assignable type on a +// and b will be set. +// +// This is useful for copying between almost identical structures that +// are requently present in auto generated code for cloud storage +// interfaces. +func SetFrom(a, b interface{}) { + ta := reflect.TypeOf(a).Elem() + tb := reflect.TypeOf(b).Elem() + va := reflect.ValueOf(a).Elem() + vb := reflect.ValueOf(b).Elem() + for i := 0; i < tb.NumField(); i++ { + bField := vb.Field(i) + tbField := tb.Field(i) + name := tbField.Name + aField := va.FieldByName(name) + taField, found := ta.FieldByName(name) + if found && aField.IsValid() && bField.IsValid() && aField.CanSet() && tbField.Type.AssignableTo(taField.Type) { + aField.Set(bField) + } + } +} + +// SetDefaults for a from b +// +// a and b should be pointers to the same kind of struct +// +// This copies the public members only from b to a. This is useful if +// you can't just use a struct copy because it contains a private +// mutex, eg as http.Transport. +func SetDefaults(a, b interface{}) { + pt := reflect.TypeOf(a) + t := pt.Elem() + va := reflect.ValueOf(a).Elem() + vb := reflect.ValueOf(b).Elem() + for i := 0; i < t.NumField(); i++ { + aField := va.Field(i) + // Set a from b if it is public + if aField.CanSet() { + bField := vb.Field(i) + aField.Set(bField) + } + } +} diff --git a/lib/structs/structs_test.go b/lib/structs/structs_test.go new file mode 100644 index 000000000..8ac8f12e6 --- /dev/null +++ b/lib/structs/structs_test.go @@ -0,0 +1,112 @@ +package structs + +import ( + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +// returns the "%p" representation of the thing passed in +func ptr(p interface{}) string { + return fmt.Sprintf("%p", p) +} + +func TestSetDefaults(t *testing.T) { + old := http.DefaultTransport.(*http.Transport) + newT := new(http.Transport) + SetDefaults(newT, old) + // Can't use assert.Equal or reflect.DeepEqual for this as it has functions in + // Check functions by comparing the "%p" representations of them + assert.Equal(t, ptr(old.Proxy), ptr(newT.Proxy), "when checking .Proxy") + assert.Equal(t, ptr(old.DialContext), ptr(newT.DialContext), "when checking .DialContext") + // Check the other public fields + assert.Equal(t, ptr(old.Dial), ptr(newT.Dial), "when checking .Dial") + assert.Equal(t, ptr(old.DialTLS), ptr(newT.DialTLS), "when checking .DialTLS") + assert.Equal(t, old.TLSClientConfig, newT.TLSClientConfig, "when checking .TLSClientConfig") + assert.Equal(t, old.TLSHandshakeTimeout, newT.TLSHandshakeTimeout, "when checking .TLSHandshakeTimeout") + assert.Equal(t, old.DisableKeepAlives, newT.DisableKeepAlives, "when checking .DisableKeepAlives") + assert.Equal(t, old.DisableCompression, newT.DisableCompression, "when checking .DisableCompression") + assert.Equal(t, old.MaxIdleConns, newT.MaxIdleConns, "when checking .MaxIdleConns") + assert.Equal(t, old.MaxIdleConnsPerHost, newT.MaxIdleConnsPerHost, "when checking .MaxIdleConnsPerHost") + assert.Equal(t, old.IdleConnTimeout, newT.IdleConnTimeout, "when checking .IdleConnTimeout") + assert.Equal(t, old.ResponseHeaderTimeout, newT.ResponseHeaderTimeout, "when checking .ResponseHeaderTimeout") + assert.Equal(t, old.ExpectContinueTimeout, newT.ExpectContinueTimeout, "when checking .ExpectContinueTimeout") + assert.Equal(t, old.TLSNextProto, newT.TLSNextProto, "when checking .TLSNextProto") + assert.Equal(t, old.MaxResponseHeaderBytes, newT.MaxResponseHeaderBytes, "when checking .MaxResponseHeaderBytes") +} + +type aType struct { + Matching string + OnlyA string + MatchingInt int + DifferentType string +} + +type bType struct { + Matching string + OnlyB string + MatchingInt int + DifferentType int + Unused string +} + +func TestSetFrom(t *testing.T) { + a := aType{ + Matching: "a", + OnlyA: "onlyA", + MatchingInt: 1, + DifferentType: "suprise", + } + + b := bType{ + Matching: "b", + OnlyB: "onlyB", + MatchingInt: 2, + DifferentType: 7, + Unused: "Ha", + } + bBefore := b + + SetFrom(&a, &b) + + assert.Equal(t, aType{ + Matching: "b", + OnlyA: "onlyA", + MatchingInt: 2, + DifferentType: "suprise", + }, a) + + assert.Equal(t, bBefore, b) +} + +func TestSetFromReversed(t *testing.T) { + a := aType{ + Matching: "a", + OnlyA: "onlyA", + MatchingInt: 1, + DifferentType: "suprise", + } + aBefore := a + + b := bType{ + Matching: "b", + OnlyB: "onlyB", + MatchingInt: 2, + DifferentType: 7, + Unused: "Ha", + } + + SetFrom(&b, &a) + + assert.Equal(t, bType{ + Matching: "a", + OnlyB: "onlyB", + MatchingInt: 1, + DifferentType: 7, + Unused: "Ha", + }, b) + + assert.Equal(t, aBefore, a) +}