2025-12-23 12:15:28 +11:00

382 lines
8.3 KiB
Go

// SPDX-License-Identifier: Apache-2.0
// Copyright 2025 Suyono
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package autoreload
import (
"context"
"crypto/x509"
"encoding/json"
"encoding/pem"
"fmt"
"os"
"slices"
"sync"
"sync/atomic"
"time"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/caddyconfig"
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
"github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
"go.uber.org/zap"
)
// AutoReload holds the process-wide state of the auto_reload watcher: the
// certificate files being watched, the Caddyfile to re-adapt and load when
// one of them changes, and the ticker that drives the periodic checks.
type AutoReload struct {
// mtx is taken for every write to the fields below (AddCertPath, SetParam,
// run, checkCert). NOTE(review): some reads happen without it, e.g. reload
// reads caddyfile and the polling loop in run reads logger.
mtx *sync.RWMutex
// certs maps a watched certificate file path to the certificate parsed from
// it on the last check; the value is nil until the first successful parse.
certs map[string]*x509.Certificate
// ctx and cancel are created once in init and never replaced. Stop (and the
// end of run) cancel ctx, after which a restarted polling loop exits right
// after its first pass.
ctx context.Context
cancel context.CancelFunc
// started is true while the polling loop in run is active; it prevents a
// second concurrent loop.
started bool
// caddyfile is the path of the Caddyfile that is re-adapted and loaded on
// reload.
caddyfile string
// interval is the time between certificate checks (DEFAULT_INTERVAL unless
// SetParam received a positive value).
interval caddy.Duration
// ticker drives the checks; created by run on first start, reset on later
// starts and by SetParam while the loop is running.
ticker *time.Ticker
// logger is supplied via SetParam.
logger *zap.Logger
}
// AutoReloadModule is the Caddy app module (ID "app.auto_reload") that
// configures and starts the shared AutoReload watcher.
//
// NOTE: the field order determines the key order of the JSON produced by
// parseAutoReloadModuleCaddyfile.
type AutoReloadModule struct {
// Interval is how often certificate files are checked; a zero or negative
// value falls back to DEFAULT_INTERVAL (see AutoReload.SetParam).
Interval caddy.Duration `json:"interval,omitempty"`
// Caddyfile is the path of the Caddyfile to re-adapt and load when a
// certificate changes. Required: Provision fails when it is empty.
Caddyfile string `json:"caddyfile,omitempty"`
// logger is set in Provision.
logger *zap.Logger
// app is the shared watcher; set by the New func in CaddyModule, or by
// Provision when still nil.
app *AutoReload
}
// DEFAULT_INTERVAL is the period between certificate checks used when no
// positive interval has been configured.
const DEFAULT_INTERVAL = caddy.Duration(60 * time.Minute)

// ar is the package-wide AutoReload singleton. It is built in init and
// handed out by GetAutoReload.
var ar *AutoReload
func init() {
	// Build the shared watcher around a cancellable root context so that
	// Stop can end the polling loop.
	rootCtx, stopFn := context.WithCancel(context.Background())
	ar = &AutoReload{
		mtx:      &sync.RWMutex{},
		certs:    map[string]*x509.Certificate{},
		interval: DEFAULT_INTERVAL,
		ctx:      rootCtx,
		cancel:   stopFn,
	}

	// Expose the app to Caddy's module registry and to the Caddyfile
	// global-options parser.
	caddy.RegisterModule(AutoReloadModule{})
	httpcaddyfile.RegisterGlobalOption("auto_reload", parseAutoReloadModuleCaddyfile)
}
// GetAutoReload returns the package-wide AutoReload instance created in
// init. It is exported so that code elsewhere can register certificate files
// via AddCertPath.
func GetAutoReload() *AutoReload {
	return ar
}
// AddCertPath registers a certificate file to be watched. An already
// registered path is left untouched, so its last parsed certificate is kept.
// A new path starts with no certificate and is filled in on the next check.
func (ar *AutoReload) AddCertPath(path string) {
	ar.mtx.Lock()
	defer ar.mtx.Unlock()

	if _, known := ar.certs[path]; known {
		return
	}
	ar.certs[path] = nil
}
// SetParam updates the watcher configuration: the Caddyfile to reload, the
// interval between certificate checks, and the logger.
//
// A zero or negative interval is replaced by DEFAULT_INTERVAL. A nil logger
// is replaced by a no-op logger, because the watcher logs through it
// unconditionally and a nil *zap.Logger would panic on the first call. If
// the polling loop is already running, its ticker is reset to the new
// interval.
func (ar *AutoReload) SetParam(caddyfile string, interval caddy.Duration, logger *zap.Logger) {
	ar.mtx.Lock()
	defer ar.mtx.Unlock()

	if interval <= 0 {
		interval = DEFAULT_INTERVAL
	}
	if logger == nil {
		logger = zap.NewNop()
	}

	ar.caddyfile = caddyfile
	ar.interval = interval
	ar.logger = logger
	ar.logger.Debug("AutoReload: SetParam")

	if ar.started && ar.ticker != nil {
		ar.ticker.Reset(time.Duration(interval))
	}
}
// run is the polling loop behind the auto_reload app; Start launches it in
// its own goroutine. Only one loop is active at a time: if the watcher is
// already marked as started, run returns immediately.
//
// The first certificate pass only records the current certificates and never
// triggers a reload. After that, every tick re-checks them. When
// checkCertificates reports a change, the configured Caddyfile is reloaded.
// Reload errors are logged and the loop keeps going. The loop ends when
// ar.ctx is cancelled; on exit it clears started and stops the ticker.
//
// NOTE(review): ar.ctx is created once in init and never replaced, so after
// it has been cancelled any later run exits right after its first pass. If
// Caddy starts a new config's apps before stopping the old config's (confirm
// against caddy.Load), a config reload leaves no active loop. That includes
// the reload this loop triggers itself.
func (ar *AutoReload) run() {
	ar.currentLogger().Debug("AutoReload: entering run()")

	var tick <-chan time.Time
	claimed := func() bool {
		ar.mtx.Lock()
		defer ar.mtx.Unlock()
		if ar.started {
			return false
		}
		ar.started = true
		if ar.ticker != nil {
			ar.ticker.Reset(time.Duration(ar.interval))
		} else {
			ar.ticker = time.NewTicker(time.Duration(ar.interval))
		}
		// Reset keeps the same channel, so caching it here stays valid even
		// when SetParam later changes the interval.
		tick = ar.ticker.C
		return true
	}()
	if !claimed {
		return
	}
	defer func() {
		ar.mtx.Lock()
		defer ar.mtx.Unlock()
		ar.started = false
		if ar.ticker != nil {
			ar.ticker.Stop()
		}
		ar.logger.Debug("AutoReload: shutting down")
	}()

	// Log selected fields under the lock instead of dumping the whole struct.
	// The certs map and the config fields can be written concurrently by
	// AddCertPath and SetParam.
	ar.logStateDebug("run with config")
	ar.currentLogger().Info("auto_reload: started")

	// The first pass only records the certificates currently on disk; its
	// result is ignored so that starting up never triggers a reload.
	ar.checkCertificates()
	ar.logStateDebug("after first pass: run with config")

	for {
		select {
		case <-ar.ctx.Done():
			// The only way out of the loop; ar.ctx is already cancelled here,
			// so no further cancel call is needed.
			return
		case <-tick:
			if !ar.checkCertificates() {
				continue
			}
			if err := ar.reload(); err != nil {
				ar.currentLogger().Error("auto_reload: reload error", zap.Error(err))
			}
		}
	}
}

// currentLogger returns the logger most recently set by SetParam, read under
// the read lock.
func (ar *AutoReload) currentLogger() *zap.Logger {
	ar.mtx.RLock()
	defer ar.mtx.RUnlock()
	return ar.logger
}

// logStateDebug logs the watcher's configuration at debug level, reading
// every field under the read lock.
func (ar *AutoReload) logStateDebug(msg string) {
	ar.mtx.RLock()
	defer ar.mtx.RUnlock()
	ar.logger.Debug(msg,
		zap.String("caddyfile", ar.caddyfile),
		zap.Duration("interval", time.Duration(ar.interval)),
		zap.Int("certificates", len(ar.certs)),
		zap.Bool("started", ar.started),
	)
}
// checkCertificates re-checks every registered certificate file concurrently,
// one goroutine per path. It reports whether any check asked for a reload,
// i.e. whether any certificate's expiry changed (see checkCert). It blocks
// until every check has finished.
func (ar *AutoReload) checkCertificates() bool {
	var (
		pending    sync.WaitGroup
		needReload atomic.Bool
	)

	// The checks are launched while the read lock is held, so the map is not
	// mutated mid-iteration. A check that needs the write lock to record a
	// new certificate simply waits until the read lock is released below.
	// That release happens before Wait, so this cannot deadlock.
	launch := func() {
		ar.mtx.RLock()
		defer ar.mtx.RUnlock()

		ar.logger.Info("auto_reload: checking certificates")
		for certPath, known := range ar.certs {
			ar.logger.Debug("checkCertificates", zap.String("cert file", certPath))
			pending.Go(func() {
				ar.checkCert(certPath, known, &needReload)
			})
		}
	}
	launch()

	pending.Wait()
	return needReload.Load()
}
// checkCert re-reads the certificate file at path and compares it with cert,
// the certificate recorded for that path on the previous check.
//
//   - When cert is nil (no successful read yet), the parsed certificate is
//     recorded and no reload is requested.
//   - Differences in the subject common name or the DNS names are only
//     logged as warnings.
//   - When NotAfter differs, the new certificate is recorded and
//     needReloadFlag is set to true.
//
// Read, PEM-decode and X.509-parse failures are logged and leave the recorded
// certificate unchanged. The check is skipped once ar.ctx has been cancelled.
//
// NOTE(review): the new certificate is recorded before run attempts the
// reload. If that reload fails, the change is not reported again on later
// ticks.
func (ar *AutoReload) checkCert(path string, cert *x509.Certificate, needReloadFlag *atomic.Bool) {
	if ar.ctx.Err() != nil {
		return
	}

	raw, err := os.ReadFile(path)
	if err != nil {
		ar.logger.Error("auto_reload: failed to read cert file", zap.String("cert file", path), zap.Error(err))
		return
	}
	// Only the first PEM block is inspected; anything after it is ignored.
	block, _ := pem.Decode(raw)
	if block == nil {
		// Include the path so the failure can be traced to a file when
		// several certificates are watched.
		ar.logger.Error("auto_reload: failed to decode cert into pem", zap.String("cert file", path))
		return
	}
	certFile, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		ar.logger.Error("auto_reload: failed to parse x509 certificate", zap.String("cert file", path), zap.Error(err))
		return
	}

	if cert == nil {
		ar.mtx.Lock()
		defer ar.mtx.Unlock()
		ar.certs[path] = certFile
		return
	}

	if cert.Subject.CommonName != certFile.Subject.CommonName {
		ar.logger.Warn("auto_reload: mismatch common name", zap.String("in-memory", cert.Subject.CommonName), zap.String("cert file", certFile.Subject.CommonName))
	}
	if len(cert.DNSNames) != len(certFile.DNSNames) {
		ar.logger.Warn("auto_reload: mismatch number of dns aliases")
	}
	// Compare DNS names as sets, so an added or removed name is reported on
	// its own. A position-by-position comparison would misalign every
	// following entry.
	onlyInMemory, onlyInFile := diffDNSNames(cert.DNSNames, certFile.DNSNames)
	for _, name := range onlyInMemory {
		ar.logger.Warn("auto_reload: mismatch dns entry", zap.String("in-memory", name))
	}
	for _, name := range onlyInFile {
		ar.logger.Warn("auto_reload: mismatch dns entry", zap.String("cert file", name))
	}

	if !certFile.NotAfter.Equal(cert.NotAfter) {
		needReloadFlag.Store(true)
		ar.mtx.Lock()
		defer ar.mtx.Unlock()
		ar.certs[path] = certFile
	}
}

// diffDNSNames treats a and b as sets. It returns the sorted, de-duplicated
// names that appear only in a and only in b. The inputs are not modified.
func diffDNSNames(a, b []string) (onlyA, onlyB []string) {
	sa := slices.Clone(a)
	slices.Sort(sa)
	sa = slices.Compact(sa)
	sb := slices.Clone(b)
	slices.Sort(sb)
	sb = slices.Compact(sb)

	for _, name := range sa {
		if _, found := slices.BinarySearch(sb, name); !found {
			onlyA = append(onlyA, name)
		}
	}
	for _, name := range sb {
		if _, found := slices.BinarySearch(sa, name); !found {
			onlyB = append(onlyB, name)
		}
	}
	return onlyA, onlyB
}
// reload re-reads the configured Caddyfile and adapts it to JSON with Caddy's
// "caddyfile" adapter. It then loads the result with caddy.Load, passing
// forceReload=true. Adapter warnings are logged; every failure is returned
// wrapped with context.
//
// The Caddyfile path and the logger are snapshotted under the read lock,
// because SetParam can replace them concurrently. The lock is released before
// caddy.Load: provisioning the new config runs AutoReloadModule.Provision,
// which calls SetParam and takes the write lock.
func (ar *AutoReload) reload() error {
	ar.mtx.RLock()
	path, logger := ar.caddyfile, ar.logger
	ar.mtx.RUnlock()

	adapter := caddyconfig.GetAdapter("caddyfile")
	if adapter == nil {
		return fmt.Errorf("cannot get caddyfile adapter")
	}

	src, err := os.ReadFile(path)
	if err != nil {
		return fmt.Errorf("cannot read %s: %w", path, err)
	}

	cfgJSON, warns, err := adapter.Adapt(src, map[string]any{
		"filename": path,
	})
	for _, w := range warns {
		logger.Warn("auto_reload: adapt warning", zap.String("warning", w.String()))
	}
	if err != nil {
		return fmt.Errorf("adapt caddyfile failed: %w", err)
	}

	if err = caddy.Load(cfgJSON, true); err != nil {
		return fmt.Errorf("reload config failed: %w", err)
	}
	return nil
}
// CaddyModule returns the module information that registers AutoReloadModule
// as the Caddy app "app.auto_reload". Every new instance is pre-wired to the
// shared AutoReload watcher.
func (AutoReloadModule) CaddyModule() caddy.ModuleInfo {
	newInstance := func() caddy.Module {
		return &AutoReloadModule{app: GetAutoReload()}
	}
	return caddy.ModuleInfo{
		ID:  "app.auto_reload",
		New: newInstance,
	}
}
// Provision sets up the module logger (with caller information), validates
// that a Caddyfile path was configured, and pushes the Caddyfile path,
// interval and logger into the shared AutoReload watcher.
func (arm *AutoReloadModule) Provision(ctx caddy.Context) error {
	// While debugging it can help to report the caller's caller instead:
	//   WithOptions(zap.AddCaller(), zap.AddCallerSkip(1))
	arm.logger = ctx.Logger(arm).WithOptions(zap.AddCaller())
	arm.logger.Debug("AutoReloadModule: entering Provision()")

	if len(arm.Caddyfile) == 0 {
		return fmt.Errorf("auto_reload: caddyfile is required")
	}

	watcher := arm.app
	if watcher == nil {
		watcher = GetAutoReload()
		arm.app = watcher
	}
	watcher.SetParam(arm.Caddyfile, arm.Interval, arm.logger)
	return nil
}
// Start launches the shared polling loop in a background goroutine and
// returns immediately. If a loop is already running, the new goroutine exits
// at once, because run guards against double starts.
func (arm *AutoReloadModule) Start() error {
	arm.logger.Debug("AutoReloadModule: Start()")
	watcher := arm.app
	go watcher.run()
	return nil
}
// Stop cancels the watcher's shared context, which ends the polling loop.
//
// NOTE(review): that context is created once in init and never recreated, so
// after Stop any later Start yields a loop that exits right after its first
// pass. This matters on config reloads if Caddy starts the new config's apps
// before stopping the old one's; confirm against Caddy's reload sequence.
func (arm *AutoReloadModule) Stop() error {
	arm.logger.Debug("AutoReloadModule: Stop()")
	cancelLoop := arm.app.cancel
	cancelLoop()
	return nil
}
// parseAutoReloadModuleCaddyfile parses the auto_reload global option
//
//	auto_reload {
//		interval  <duration>
//		caddyfile <path>
//	}
//
// and returns it as the JSON config of the "app.auto_reload" app. The option
// may appear only once. Arguments on the auto_reload line itself are
// rejected, as are a missing subdirective value and any extra value after it.
// Previously the first case was silently ignored, and extra values surfaced
// as a misleading "unrecognized subdirective" error.
func parseAutoReloadModuleCaddyfile(d *caddyfile.Dispenser, existing any) (any, error) {
	if existing != nil {
		return nil, fmt.Errorf("auto_reload must be defined once")
	}
	// The shared app is not attached here: only exported fields survive the
	// JSON round trip, and CaddyModule's New / Provision attach it.
	arm := new(AutoReloadModule)

	for d.Next() {
		// auto_reload takes only a block, no inline arguments. NextArg does
		// not consume a block-opening "{".
		if d.NextArg() {
			return nil, d.ArgErr()
		}
		for d.NextBlock(0) {
			switch d.Val() {
			case "interval":
				if !d.NextArg() {
					return nil, d.ArgErr()
				}
				dur, err := caddy.ParseDuration(d.Val())
				if err != nil {
					return nil, d.Errf("invalid interval %q: %v", d.Val(), err)
				}
				arm.Interval = caddy.Duration(dur)
			case "caddyfile":
				if !d.NextArg() {
					return nil, d.ArgErr()
				}
				arm.Caddyfile = d.Val()
			default:
				return nil, d.Errf("unrecognized subdirective %q", d.Val())
			}
			// Each subdirective takes exactly one value.
			if d.NextArg() {
				return nil, d.ArgErr()
			}
		}
	}

	b, err := json.Marshal(arm)
	if err != nil {
		return nil, err
	}
	return httpcaddyfile.App{
		Name:  "app.auto_reload",
		Value: b,
	}, nil
}