Simplify and unify various returns (#893)

This commit is contained in:
Tom Thorogood 2019-01-04 20:49:42 +10:30 committed by Miek Gieben
parent 57ca5ae8f4
commit 513c1ff221
10 changed files with 39 additions and 74 deletions

View File

@ -325,11 +325,7 @@ func (co *Conn) Read(p []byte) (n int, err error) {
return tcpRead(r, p[:l]) return tcpRead(r, p[:l])
} }
// UDP connection // UDP connection
n, err = co.Conn.Read(p) return co.Conn.Read(p)
if err != nil {
return n, err
}
return n, err
} }
// WriteMsg sends a message through the connection co. // WriteMsg sends a message through the connection co.
@ -351,10 +347,8 @@ func (co *Conn) WriteMsg(m *Msg) (err error) {
if err != nil { if err != nil {
return err return err
} }
if _, err = co.Write(out); err != nil { _, err = co.Write(out)
return err return err
}
return nil
} }
// Write implements the net.Conn Write method. // Write implements the net.Conn Write method.
@ -376,8 +370,7 @@ func (co *Conn) Write(p []byte) (n int, err error) {
n, err := io.Copy(w, bytes.NewReader(p)) n, err := io.Copy(w, bytes.NewReader(p))
return int(n), err return int(n), err
} }
n, err = co.Conn.Write(p) return co.Conn.Write(p)
return n, err
} }
// Return the appropriate timeout for a specific request // Return the appropriate timeout for a specific request
@ -444,11 +437,7 @@ func ExchangeConn(c net.Conn, m *Msg) (r *Msg, err error) {
// DialTimeout acts like Dial but takes a timeout. // DialTimeout acts like Dial but takes a timeout.
func DialTimeout(network, address string, timeout time.Duration) (conn *Conn, err error) { func DialTimeout(network, address string, timeout time.Duration) (conn *Conn, err error) {
client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}} client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}}
conn, err = client.Dial(address) return client.Dial(address)
if err != nil {
return nil, err
}
return conn, nil
} }
// DialWithTLS connects to the address on the named network with TLS. // DialWithTLS connects to the address on the named network with TLS.
@ -457,12 +446,7 @@ func DialWithTLS(network, address string, tlsConfig *tls.Config) (conn *Conn, er
network += "-tls" network += "-tls"
} }
client := Client{Net: network, TLSConfig: tlsConfig} client := Client{Net: network, TLSConfig: tlsConfig}
conn, err = client.Dial(address) return client.Dial(address)
if err != nil {
return nil, err
}
return conn, nil
} }
// DialTimeoutWithTLS acts like DialWithTLS but takes a timeout. // DialTimeoutWithTLS acts like DialWithTLS but takes a timeout.
@ -471,11 +455,7 @@ func DialTimeoutWithTLS(network, address string, tlsConfig *tls.Config, timeout
network += "-tls" network += "-tls"
} }
client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}, TLSConfig: tlsConfig} client := Client{Net: network, Dialer: &net.Dialer{Timeout: timeout}, TLSConfig: tlsConfig}
conn, err = client.Dial(address) return client.Dial(address)
if err != nil {
return nil, err
}
return conn, nil
} }
// ExchangeContext acts like Exchange, but honors the deadline on the provided // ExchangeContext acts like Exchange, but honors the deadline on the provided

View File

@ -36,8 +36,7 @@ func SplitDomainName(s string) (labels []string) {
} }
} }
labels = append(labels, s[begin:fqdnEnd]) return append(labels, s[begin:fqdnEnd])
return labels
} }
// CompareDomainName compares the names s1 and s2 and // CompareDomainName compares the names s1 and s2 and

12
msg.go
View File

@ -1072,7 +1072,7 @@ func compressionLenSearch(c map[string]struct{}, s string, msgOff int) (int, boo
} }
// Copy returns a new RR which is a deep-copy of r. // Copy returns a new RR which is a deep-copy of r.
func Copy(r RR) RR { r1 := r.copy(); return r1 } func Copy(r RR) RR { return r.copy() }
// Len returns the length (in octets) of the uncompressed RR in wire format. // Len returns the length (in octets) of the uncompressed RR in wire format.
func Len(r RR) int { return r.len(0, nil) } func Len(r RR) int { return r.len(0, nil) }
@ -1187,7 +1187,10 @@ func (dh *Header) pack(msg []byte, off int, compression compressionMap, compress
return off, err return off, err
} }
off, err = packUint16(dh.Arcount, msg, off) off, err = packUint16(dh.Arcount, msg, off)
return off, err if err != nil {
return off, err
}
return off, nil
} }
func unpackMsgHdr(msg []byte, off int) (Header, int, error) { func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
@ -1216,7 +1219,10 @@ func unpackMsgHdr(msg []byte, off int) (Header, int, error) {
return dh, off, err return dh, off, err
} }
dh.Arcount, off, err = unpackUint16(msg, off) dh.Arcount, off, err = unpackUint16(msg, off)
return dh, off, err if err != nil {
return dh, off, err
}
return dh, off, nil
} }
// setHdr set the header in the dns using the binary data in dh. // setHdr set the header in the dns using the binary data in dh.

View File

@ -39,11 +39,12 @@ func mkPrivateRR(rrtype uint16) *PrivateRR {
} }
anyrr := rrfunc() anyrr := rrfunc()
switch rr := anyrr.(type) { rr, ok := anyrr.(*PrivateRR)
case *PrivateRR: if !ok {
return rr panic(fmt.Sprintf("dns: RR is not a PrivateRR, TypeToRR[%d] generator returned %T", rrtype, anyrr))
} }
panic(fmt.Sprintf("dns: RR is not a PrivateRR, TypeToRR[%d] generator returned %T", rrtype, anyrr))
return rr
} }
// Header return the RR header of r. // Header return the RR header of r.
@ -82,11 +83,7 @@ func (r *PrivateRR) pack(msg []byte, off int, compression compressionMap, compre
func (r *PrivateRR) unpack(msg []byte, off int) (int, error) { func (r *PrivateRR) unpack(msg []byte, off int) (int, error) {
off1, err := r.Data.Unpack(msg[off:]) off1, err := r.Data.Unpack(msg[off:])
off += off1 off += off1
if err != nil { return off, err
return off, err
}
return off, nil
} }
// PrivateHandle registers a private resource record type. It requires // PrivateHandle registers a private resource record type. It requires

View File

@ -463,11 +463,10 @@ var testShutdownNotify *sync.Cond
// getReadTimeout is a helper func to use system timeout if server did not intend to change it. // getReadTimeout is a helper func to use system timeout if server did not intend to change it.
func (srv *Server) getReadTimeout() time.Duration { func (srv *Server) getReadTimeout() time.Duration {
rtimeout := dnsTimeout
if srv.ReadTimeout != 0 { if srv.ReadTimeout != 0 {
rtimeout = srv.ReadTimeout return srv.ReadTimeout
} }
return rtimeout return dnsTimeout
} }
// serveTCP starts a TCP listener for the server. // serveTCP starts a TCP listener for the server.
@ -783,8 +782,7 @@ func (w *response) Write(m []byte) (int, error) {
switch { switch {
case w.udp != nil: case w.udp != nil:
n, err := WriteToSessionUDP(w.udp, m, w.udpSession) return WriteToSessionUDP(w.udp, m, w.udpSession)
return n, err
case w.tcp != nil: case w.tcp != nil:
lm := len(m) lm := len(m)
if lm < 2 { if lm < 2 {

View File

@ -14,10 +14,7 @@ func (r *SMIMEA) Sign(usage, selector, matchingType int, cert *x509.Certificate)
r.MatchingType = uint8(matchingType) r.MatchingType = uint8(matchingType)
r.Certificate, err = CertificateToDANE(r.Selector, r.MatchingType, cert) r.Certificate, err = CertificateToDANE(r.Selector, r.MatchingType, cert)
if err != nil { return err
return err
}
return nil
} }
// Verify verifies a SMIMEA record against an SSL certificate. If it is OK // Verify verifies a SMIMEA record against an SSL certificate. If it is OK

View File

@ -14,10 +14,7 @@ func (r *TLSA) Sign(usage, selector, matchingType int, cert *x509.Certificate) (
r.MatchingType = uint8(matchingType) r.MatchingType = uint8(matchingType)
r.Certificate, err = CertificateToDANE(r.Selector, r.MatchingType, cert) r.Certificate, err = CertificateToDANE(r.Selector, r.MatchingType, cert)
if err != nil { return err
return err
}
return nil
} }
// Verify verifies a TLSA record against an SSL certificate. If it is OK // Verify verifies a TLSA record against an SSL certificate. If it is OK

View File

@ -134,12 +134,11 @@ func TsigGenerate(m *Msg, secret, requestMAC string, timersOnly bool) ([]byte, s
t.OrigId = m.Id t.OrigId = m.Id
tbuf := make([]byte, Len(t)) tbuf := make([]byte, Len(t))
if off, err := PackRR(t, tbuf, 0, nil, false); err == nil { off, err := PackRR(t, tbuf, 0, nil, false)
tbuf = tbuf[:off] // reset to actual size used if err != nil {
} else {
return nil, "", err return nil, "", err
} }
mbuf = append(mbuf, tbuf...) mbuf = append(mbuf, tbuf[:off]...)
// Update the ArCount directly in the buffer. // Update the ArCount directly in the buffer.
binary.BigEndian.PutUint16(mbuf[10:], uint16(len(m.Extra)+1)) binary.BigEndian.PutUint16(mbuf[10:], uint16(len(m.Extra)+1))

View File

@ -20,15 +20,13 @@ func ReadFromSessionUDP(conn *net.UDPConn, b []byte) (int, *SessionUDP, error) {
if err != nil { if err != nil {
return n, nil, err return n, nil, err
} }
session := &SessionUDP{raddr.(*net.UDPAddr)} return n, &SessionUDP{raddr.(*net.UDPAddr)}, err
return n, session, err
} }
// WriteToSessionUDP acts just like net.UDPConn.WriteTo(), but uses a *SessionUDP instead of a net.Addr. // WriteToSessionUDP acts just like net.UDPConn.WriteTo(), but uses a *SessionUDP instead of a net.Addr.
// TODO(fastest963): Once go1.10 is released, use WriteMsgUDP. // TODO(fastest963): Once go1.10 is released, use WriteMsgUDP.
func WriteToSessionUDP(conn *net.UDPConn, b []byte, session *SessionUDP) (int, error) { func WriteToSessionUDP(conn *net.UDPConn, b []byte, session *SessionUDP) (int, error) {
n, err := conn.WriteTo(b, session.raddr) return conn.WriteTo(b, session.raddr)
return n, err
} }
// TODO(fastest963): Once go1.10 is released and we can use *MsgUDP methods // TODO(fastest963): Once go1.10 is released and we can use *MsgUDP methods

18
xfr.go
View File

@ -243,24 +243,18 @@ func (t *Transfer) WriteMsg(m *Msg) (err error) {
if err != nil { if err != nil {
return err return err
} }
if _, err = t.Write(out); err != nil { _, err = t.Write(out)
return err return err
}
return nil
} }
func isSOAFirst(in *Msg) bool { func isSOAFirst(in *Msg) bool {
if len(in.Answer) > 0 { return len(in.Answer) > 0 &&
return in.Answer[0].Header().Rrtype == TypeSOA in.Answer[0].Header().Rrtype == TypeSOA
}
return false
} }
func isSOALast(in *Msg) bool { func isSOALast(in *Msg) bool {
if len(in.Answer) > 0 { return len(in.Answer) > 0 &&
return in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA in.Answer[len(in.Answer)-1].Header().Rrtype == TypeSOA
}
return false
} }
const errXFR = "bad xfr rcode: %d" const errXFR = "bad xfr rcode: %d"