From 513c1ff2211924517b6015ebdbb995e8a10c028c Mon Sep 17 00:00:00 2001 From: Tom Thorogood Date: Fri, 4 Jan 2019 20:49:42 +1030 Subject: [PATCH] Simplify and unify various returns (#893) --- client.go | 34 +++++++--------------------------- labels.go | 3 +-- msg.go | 12 +++++++++--- privaterr.go | 15 ++++++--------- server.go | 8 +++----- smimea.go | 5 +---- tlsa.go | 5 +---- tsig.go | 7 +++---- udp_windows.go | 6 ++---- xfr.go | 18 ++++++------------ 10 files changed, 39 insertions(+), 74 deletions(-) diff --git a/client.go b/client.go index 770a946c..007a0f84 100644 --- a/client.go +++ b/client.go @@ -325,11 +325,7 @@ func (co *Conn) Read(p []byte) (n int, err error) { return tcpRead(r, p[:l]) } // UDP connection - n, err = co.Conn.Read(p) - if err != nil { - return n, err - } - return n, err + return co.Conn.Read(p) } // WriteMsg sends a message through the connection co. @@ -351,10 +347,8 @@ func (co *Conn) WriteMsg(m *Msg) (err error) { if err != nil { return err } - if _, err = co.Write(out); err != nil { - return err - } - return nil + _, err = co.Write(out) + return err } // Write implements the net.Conn Write method. @@ -376,8 +370,7 @@ func (co *Conn) Write(p []byte) (n int, err error) { n, err := io.Copy(w, bytes.NewReader(p)) return int(n), err } - n, err = co.Conn.Write(p) - return n, err + return co.Conn.Write(p) } // Return the appropriate timeout for a specific request @@ -444,11 +437,7 @@ func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) { // DialTimeout acts like Dial but takes a timeout. func DialTimeout(network, address string, timeout time.Duration) (conn *Conn, err error) { client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}} - conn, err = client.Dial(address) - if err != nil { - return nil, err - } - return conn, nil + return client.Dial(address) } // DialWithTLS connects to the address on the named network with TLS. @@ -457,12 +446,7 @@ func DialWithTLS(network, address string, tlsConfig *tls.Config) (conn *Conn, er network += "-tls" } client := Client{Net: network, TLSConfig: tlsConfig} - conn, err = client.Dial(address) - - if err != nil { - return nil, err - } - return conn, nil + return client.Dial(address) } // DialTimeoutWithTLS acts like DialWithTLS but takes a timeout. @@ -471,11 +455,7 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout network += "-tls" } client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}, TLSConfig: tlsConfig} - conn, err = client.Dial(address) - if err != nil { - return nil, err - } - return conn, nil + return client.Dial(address) } // ExchangeContext acts like Exchange, but honors the deadline on the provided diff --git a/labels.go b/labels.go index bc182ba3..ca8c2045 100644 --- a/labels.go +++ b/labels.go @@ -36,8 +36,7 @@ func SplitDomainName(s string) (labels []string) { } } - labels = append(labels, s[begin:fqdnEnd]) - return labels + return append(labels, s[begin:fqdnEnd]) } // CompareDomainName compares the names s1 and s2 and diff --git a/msg.go b/msg.go index bfd0011b..5191fc06 100644 --- a/msg.go +++ b/msg.go @@ -1072,7 +1072,7 @@ func compressionLenSearch(c map[string]struct{}, s string, msgOff int) (int, boo } // Copy returns a new RR which is a deep-copy of r. -func Copy(r RR) RR { r1 := r.copy(); return r1 } +func Copy(r RR) RR { return r.copy() } // Len returns the length (in octets) of the uncompressed RR in wire format. func Len(r RR) int { return r.len(0, nil) } @@ -1187,7 +1187,10 @@ func (dh *Header) pack(msg []byte, off int, compression compressionMap, compress return off, err } off, err = packUint16(dh.Arcount, msg, off) - return off, err + if err != nil { + return off, err + } + return off, nil } func unpackMsgHdr(msg []byte, off int) (Header, int, error) { @@ -1216,7 +1219,10 @@ func unpackMsgHdr(msg []byte, off int) (Header, int, error) { return dh, off, err } dh.Arcount, off, err = unpackUint16(msg, off) - return dh, off, err + if err != nil { + return dh, off, err + } + return dh, off, nil } // setHdr set the header in the dns using the binary data in dh. diff --git a/privaterr.go b/privaterr.go index 28c41d1c..e84af220 100644 --- a/privaterr.go +++ b/privaterr.go @@ -39,11 +39,12 @@ func mkPrivateRR(rrtype uint16) *PrivateRR { } anyrr := rrfunc() - switch rr := anyrr.(type) { - case *PrivateRR: - return rr + rr, ok := anyrr.(*PrivateRR) + if !ok { + panic(fmt.Sprintf("dns: RR is not a PrivateRR, TypeToRR[%d] generator returned %T", rrtype, anyrr)) } - panic(fmt.Sprintf("dns: RR is not a PrivateRR, TypeToRR[%d] generator returned %T", rrtype, anyrr)) + + return rr } // Header return the RR header of r. @@ -82,11 +83,7 @@ func (r *PrivateRR) pack(msg []byte, off int, compression compressionMap, compre func (r *PrivateRR) unpack(msg []byte, off int) (int, error) { off1, err := r.Data.Unpack(msg[off:]) off += off1 - if err != nil { - return off, err - } - - return off, nil + return off, err } // PrivateHandle registers a private resource record type. It requires diff --git a/server.go b/server.go index c9f0c533..88240370 100644 --- a/server.go +++ b/server.go @@ -463,11 +463,10 @@ var testShutdownNotify *sync.Cond // getReadTimeout is a helper func to use system timeout if server did not intend to change it. func (srv *Server) getReadTimeout() time.Duration { - rtimeout := dnsTimeout if srv.ReadTimeout != 0 { - rtimeout = srv.ReadTimeout + return srv.ReadTimeout } - return rtimeout + return dnsTimeout } // serveTCP starts a TCP listener for the server. @@ -783,8 +782,7 @@ func (w *response) Write(m []byte) (int, error) { switch { case w.udp != nil: - n, err := WriteToSessionUDP(w.udp, m, w.udpSession) - return n, err + return WriteToSessionUDP(w.udp, m, w.udpSession) case w.tcp != nil: lm := len(m) if lm < 2 { diff --git a/smimea.go b/smimea.go index 4e7ded4b..89f09f0d 100644 --- a/smimea.go +++ b/smimea.go @@ -14,10 +14,7 @@ func (r *SMIMEA) Sign(usage, selector, matchingType int, cert *x509.Certificate) r.MatchingType = uint8(matchingType) r.Certificate, err = CertificateToDANE(r.Selector, r.MatchingType, cert) - if err != nil { - return err - } - return nil + return err } // Verify verifies a SMIMEA record against an SSL certificate. If it is OK diff --git a/tlsa.go b/tlsa.go index 431e2fb5..4e07983b 100644 --- a/tlsa.go +++ b/tlsa.go @@ -14,10 +14,7 @@ func (r *TLSA) Sign(usage, selector, matchingType int, cert *x509.Certificate) ( r.MatchingType = uint8(matchingType) r.Certificate, err = CertificateToDANE(r.Selector, r.MatchingType, cert) - if err != nil { - return err - } - return nil + return err } // Verify verifies a TLSA record against an SSL certificate. If it is OK diff --git a/tsig.go b/tsig.go index 91b69d58..c98fd166 100644 --- a/tsig.go +++ b/tsig.go @@ -134,12 +134,11 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s t.OrigId = m.Id tbuf := make([]byte, Len(t)) - if off, err := PackRR(t, tbuf, 0, nil, false); err == nil { - tbuf = tbuf[:off] // reset to actual size used - } else { + off, err := PackRR(t, tbuf, 0, nil, false) + if err != nil { return nil, "", err } - mbuf = append(mbuf, tbuf...) + mbuf = append(mbuf, tbuf[:off]...) // Update the ArCount directly in the buffer. binary.BigEndian.PutUint16(mbuf[10:], uint16(len(m.Extra)+1)) diff --git a/udp_windows.go b/udp_windows.go index 6778c3c6..e7dd8ca3 100644 --- a/udp_windows.go +++ b/udp_windows.go @@ -20,15 +20,13 @@ func ReadFromSessionUDP(conn *net.UDPConn, b []byte) (int, *SessionUDP, error) { if err != nil { return n, nil, err } - session := &SessionUDP{raddr.(*net.UDPAddr)} - return n, session, err + return n, &SessionUDP{raddr.(*net.UDPAddr)}, err } // WriteToSessionUDP acts just like net.UDPConn.WriteTo(), but uses a *SessionUDP instead of a net.Addr. // TODO(fastest963): Once go1.10 is released, use WriteMsgUDP. func WriteToSessionUDP(conn *net.UDPConn, b []byte, session *SessionUDP) (int, error) { - n, err := conn.WriteTo(b, session.raddr) - return n, err + return conn.WriteTo(b, session.raddr) } // TODO(fastest963): Once go1.10 is released and we can use *MsgUDP methods diff --git a/xfr.go b/xfr.go index f2b5902b..6e36577b 100644 --- a/xfr.go +++ b/xfr.go @@ -243,24 +243,18 @@ func (t *Transfer) WriteMsg(m *Msg) (err error) { if err != nil { return err } - if _, err = t.Write(out); err != nil { - return err - } - return nil + _, err = t.Write(out) + return err } func isSOAFirst(in *Msg) bool { - if len(in.Answer) > 0 { - return in.Answer[0].Header().Rrtype == TypeSOA - } - return false + return len(in.Answer) > 0 && + in.Answer[0].Header().Rrtype == TypeSOA } func isSOALast(in *Msg) bool { - if len(in.Answer) > 0 { - return in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA - } - return false + return len(in.Answer) > 0 && + in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA } const errXFR = "bad xfr rcode: %d"