// SPDX-License-Identifier: Apache-2.0
// Copyright 2025 Suyono
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package autoreload

import (
	"context"
	"crypto/x509"
	"encoding/json"
	"encoding/pem"
	"fmt"
	"os"
	"slices"
	"sync"
	"sync/atomic"
	"time"

	"github.com/caddyserver/caddy/v2"
	"github.com/caddyserver/caddy/v2/caddyconfig"
	"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
	"github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
	"go.uber.org/zap"
)

type AutoReload struct {
	mtx       *sync.RWMutex
	certs     map[string]*x509.Certificate
	ctx       context.Context
	cancel    context.CancelFunc
	started   bool
	caddyfile string
	interval  caddy.Duration
	ticker    *time.Ticker
	logger    *zap.Logger
}

type AutoReloadModule struct {
	Interval  caddy.Duration `json:"interval,omitempty"`
	Caddyfile string         `json:"caddyfile,omitempty"`

	logger *zap.Logger
	app    *AutoReload
}

const (
	DEFAULT_INTERVAL = caddy.Duration(time.Hour)
)

var ar *AutoReload

func init() {
	ctx, cancel := context.WithCancel(context.Background())
	ar = &AutoReload{
		mtx:      new(sync.RWMutex),
		ctx:      ctx,
		cancel:   cancel,
		certs:    make(map[string]*x509.Certificate),
		interval: DEFAULT_INTERVAL,
	}

	caddy.RegisterModule(AutoReloadModule{})
	httpcaddyfile.RegisterGlobalOption("auto_reload", parseAutoReloadModuleCaddyfile)
}

func GetAutoReload() *AutoReload {
	return ar
}

func (ar *AutoReload) AddCertPath(path string) {
	ar.mtx.Lock()
	defer ar.mtx.Unlock()

	if _, ok := ar.certs[path]; !ok {
		ar.certs[path] = nil
	}
}

func (ar *AutoReload) SetParam(caddyfile string, interval caddy.Duration, logger *zap.Logger) {
	ar.mtx.Lock()
	defer ar.mtx.Unlock()

	if int64(interval) <= int64(0) {
		interval = DEFAULT_INTERVAL
	}

	ar.caddyfile = caddyfile
	ar.interval = interval
	ar.logger = logger
	ar.logger.Debug("AutoReload: SetParam")

	if ar.started && ar.ticker != nil {
		ar.ticker.Reset(time.Duration(interval))
	}
}

func (ar *AutoReload) run() {
	ar.logger.Debug("AutoReload: entering run()")
	if !func() bool {
		ar.mtx.Lock()
		defer ar.mtx.Unlock()

		if ar.started {
			return false
		}

		ar.started = true
		if ar.ticker != nil {
			ar.ticker.Reset(time.Duration(ar.interval))
		} else {
			ar.ticker = time.NewTicker(time.Duration(ar.interval))
		}
		return true
	}() {
		return
	}

	defer func() {
		ar.mtx.Lock()
		defer ar.mtx.Unlock()

		ar.started = false
		if ar.ticker != nil {
			ar.ticker.Stop()
		}
		ar.logger.Debug("AutoReload: shutting down")
	}()

	ar.logger.Debug("run with config", zap.String("struct", fmt.Sprintf("%#v", ar)))
	ar.logger.Info("auto_reload: started")
	defer ar.cancel()

	ar.checkCertificates()
	ar.logger.Debug("after first pass: run with config", zap.String("struct", fmt.Sprintf("%#v", ar)))

	for {
		select {
		case <-ar.ctx.Done():
			//TODO: shutdown
			return
		case <-ar.ticker.C:
			if ar.checkCertificates() {
				if err := ar.reload(); err != nil {
					ar.logger.Error("auto_reload: reload error", zap.Error(err))
				}
			}
		}
	}
}
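// checkCertificates walks the registered certificate paths and re-reads each
// file concurrently via checkCert. It reports true when at least one on-disk
// certificate carries a different NotAfter than the copy held in memory,
// which is the signal run() uses to trigger a config reload.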
func (ar *AutoReload) checkCertificates() bool {
	wg := new(sync.WaitGroup)
	flag := new(atomic.Bool)
	flag.Store(false)

	func() {
		ar.mtx.RLock()
		defer ar.mtx.RUnlock()

		ar.logger.Info("auto_reload: checking certificates")
		for path, cert := range ar.certs {
			ar.logger.Debug("checkCertificates", zap.String("cert file", path))
			wg.Go(func() {
				ar.checkCert(path, cert, flag)
			})
		}
	}()

	wg.Wait()
	return flag.Load()
}

func (ar *AutoReload) checkCert(path string, cert *x509.Certificate, needReloadFlag *atomic.Bool) {
	if ar.ctx.Err() != nil {
		return
	}

	b, err := os.ReadFile(path)
	if err != nil {
		ar.logger.Error("auto_reload: failed to read cert file", zap.String("cert file", path), zap.Error(err))
		return
	}

	p, rest := pem.Decode(b)
	if p == nil {
		ar.logger.Error("auto_reload: failed to decode cert into pem")
		return
	}
	_ = rest

	var (
		certFile *x509.Certificate
	)
	if certFile, err = x509.ParseCertificate(p.Bytes); err != nil {
		ar.logger.Error("auto_reload: failed to parse x509 certificate", zap.String("cert file", path), zap.Error(err))
		return
	}

	if cert == nil {
		ar.mtx.Lock()
		defer ar.mtx.Unlock()

		ar.certs[path] = certFile
		return
	}

	if cert.Subject.CommonName != certFile.Subject.CommonName {
		ar.logger.Warn("auto_reload: mismatch common name",
			zap.String("in-memory", cert.Subject.CommonName),
			zap.String("cert file", certFile.Subject.CommonName))
	}

	n := len(cert.DNSNames)
	if len(cert.DNSNames) != len(certFile.DNSNames) {
		ar.logger.Warn("auto_reload: mismatch number of dns aliases")
		if len(certFile.DNSNames) < len(cert.DNSNames) {
			n = len(certFile.DNSNames)
		}
	}

	certDNS := make([]string, len(cert.DNSNames))
	copy(certDNS, cert.DNSNames)
	slices.Sort(certDNS)

	certFileDNS := make([]string, len(certFile.DNSNames))
	copy(certFileDNS, certFile.DNSNames)
	slices.Sort(certFileDNS)

	for i := 0; i < n; i++ {
		if certDNS[i] != certFileDNS[i] {
			ar.logger.Warn("auto_reload: mismatch dns entry",
				zap.String("in-memory", certDNS[i]),
				zap.String("cert file", certFileDNS[i]))
		}
	}

	if !certFile.NotAfter.Equal(cert.NotAfter) {
		needReloadFlag.Store(true)

		ar.mtx.Lock()
		defer ar.mtx.Unlock()

		ar.certs[path] = certFile
	}
}

func (ar *AutoReload) reload() error {
	adapter := caddyconfig.GetAdapter("caddyfile")
	if adapter == nil {
		return fmt.Errorf("cannot get caddyfile adapter")
	}

	var (
		b     []byte
		warns []caddyconfig.Warning
		err   error
	)
	if b, err = os.ReadFile(ar.caddyfile); err != nil {
		return fmt.Errorf("cannot read %s: %w", ar.caddyfile, err)
	}

	b, warns, err = adapter.Adapt(b, map[string]any{
		"filename": ar.caddyfile,
	})
	for _, w := range warns {
		ar.logger.Warn("auto_reload: adapt warning", zap.String("warning", w.String()))
	}
	if err != nil {
		return fmt.Errorf("adapt caddyfile failed: %w", err)
	}

	if err = caddy.Load(b, true); err != nil {
		return fmt.Errorf("reload config failed: %w", err)
	}

	return nil
}

func (AutoReloadModule) CaddyModule() caddy.ModuleInfo {
	return caddy.ModuleInfo{
		ID: "app.auto_reload",
		New: func() caddy.Module {
			return &AutoReloadModule{
				app: GetAutoReload(),
			}
		},
	}
}

func (arm *AutoReloadModule) Provision(ctx caddy.Context) error {
	arm.logger = ctx.Logger(arm)
	// useful for debugging
	// arm.logger = arm.logger.WithOptions(zap.AddCaller(), zap.AddCallerSkip(1))
	arm.logger = arm.logger.WithOptions(zap.AddCaller())
	arm.logger.Debug("AutoReloadModule: entering Provision()")

	if arm.Caddyfile == "" {
		return fmt.Errorf("auto_reload: caddyfile is required")
	}

	if arm.app == nil {
		arm.app = GetAutoReload()
	}
	arm.app.SetParam(arm.Caddyfile, arm.Interval, arm.logger)

	return nil
}

func (arm *AutoReloadModule) Start() error {
	arm.logger.Debug("AutoReloadModule: Start()")
	go arm.app.run()
	return nil
}

func (arm *AutoReloadModule) Stop() error {
	arm.logger.Debug("AutoReloadModule: Stop()")
	arm.app.cancel()
	return nil
}
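// parseAutoReloadModuleCaddyfile unmarshals the "auto_reload" global option
// into an AutoReloadModule and returns it to the Caddyfile adapter as the
// app.auto_reload app. A minimal sketch of the expected global option block
// follows; the Caddyfile path and interval are illustrative values only:
//
//	{
//		auto_reload {
//			interval 1h
//			caddyfile /etc/caddy/Caddyfile
//		}
//	}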
func parseAutoReloadModuleCaddyfile(d *caddyfile.Dispenser, existing any) (any, error) {
	if existing != nil {
		return nil, fmt.Errorf("auto_reload can only be defined once")
	}

	arm := new(AutoReloadModule) // no need to inject AutoReload here as we're going to marshal it into json
	var (
		dur time.Duration
		err error
		b   []byte
	)
	for d.Next() {
		for d.NextBlock(0) {
			switch d.Val() {
			case "interval":
				if !d.NextArg() {
					return nil, d.ArgErr()
				}
				if dur, err = caddy.ParseDuration(d.Val()); err != nil {
					return nil, d.Errf("invalid interval %q: %v", d.Val(), err)
				}
				arm.Interval = caddy.Duration(dur)
			case "caddyfile":
				if !d.NextArg() {
					return nil, d.ArgErr()
				}
				arm.Caddyfile = d.Val()
			default:
				return nil, d.Errf("unrecognized subdirective %q", d.Val())
			}
		}
	}

	if b, err = json.Marshal(arm); err != nil {
		return nil, err
	}

	return httpcaddyfile.App{
		Name:  "app.auto_reload",
		Value: b,
	}, nil
}