382 lines
8.3 KiB
Go
382 lines
8.3 KiB
Go
// SPDX-License-Identifier: Apache-2.0

// Copyright 2025 Suyono
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
|
|
|
package autoreload
|
|
|
|
import (
|
|
"context"
|
|
"crypto/x509"
|
|
"encoding/json"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"os"
|
|
"slices"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/caddyserver/caddy/v2"
|
|
"github.com/caddyserver/caddy/v2/caddyconfig"
|
|
"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
|
|
"github.com/caddyserver/caddy/v2/caddyconfig/httpcaddyfile"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
type AutoReload struct {
|
|
mtx *sync.RWMutex
|
|
certs map[string]*x509.Certificate
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
started bool
|
|
caddyfile string
|
|
interval caddy.Duration
|
|
ticker *time.Ticker
|
|
logger *zap.Logger
|
|
}
|
|
|
|
type AutoReloadModule struct {
|
|
Interval caddy.Duration `json:"interval,omitempty"`
|
|
Caddyfile string `json:"caddyfile,omitempty"`
|
|
logger *zap.Logger
|
|
app *AutoReload
|
|
}
|
|
|
|
const (
|
|
DEFAULT_INTERVAL = caddy.Duration(time.Hour)
|
|
)
|
|
|
|
var ar *AutoReload
|
|
|
|
func init() {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
ar = &AutoReload{
|
|
mtx: new(sync.RWMutex),
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
certs: make(map[string]*x509.Certificate),
|
|
interval: DEFAULT_INTERVAL,
|
|
}
|
|
|
|
caddy.RegisterModule(AutoReloadModule{})
|
|
httpcaddyfile.RegisterGlobalOption("auto_reload", parseAutoReloadModuleCaddyfile)
|
|
}
|
|
|
|
func GetAutoReload() *AutoReload {
|
|
return ar
|
|
}
|
|
|
|
func (ar *AutoReload) AddCertPath(path string) {
|
|
ar.mtx.Lock()
|
|
defer ar.mtx.Unlock()
|
|
|
|
if _, ok := ar.certs[path]; !ok {
|
|
ar.certs[path] = nil
|
|
}
|
|
}
|
|
|
|
func (ar *AutoReload) SetParam(caddyfile string, interval caddy.Duration, logger *zap.Logger) {
|
|
ar.mtx.Lock()
|
|
defer ar.mtx.Unlock()
|
|
|
|
if int64(interval) <= int64(0) {
|
|
interval = DEFAULT_INTERVAL
|
|
}
|
|
|
|
ar.caddyfile = caddyfile
|
|
ar.interval = interval
|
|
ar.logger = logger
|
|
|
|
ar.logger.Debug("AutoReload: SetParam")
|
|
|
|
if ar.started && ar.ticker != nil {
|
|
ar.ticker.Reset(time.Duration(interval))
|
|
}
|
|
}
|
|
|
|
func (ar *AutoReload) run() {
|
|
ar.logger.Debug("AutoReload: entering run()")
|
|
if !func() bool {
|
|
ar.mtx.Lock()
|
|
defer ar.mtx.Unlock()
|
|
|
|
if ar.started {
|
|
return false
|
|
}
|
|
|
|
ar.started = true
|
|
if ar.ticker != nil {
|
|
ar.ticker.Reset(time.Duration(ar.interval))
|
|
} else {
|
|
ar.ticker = time.NewTicker(time.Duration(ar.interval))
|
|
}
|
|
|
|
return true
|
|
}() {
|
|
return
|
|
}
|
|
|
|
defer func() {
|
|
ar.mtx.Lock()
|
|
defer ar.mtx.Unlock()
|
|
|
|
ar.started = false
|
|
if ar.ticker != nil {
|
|
ar.ticker.Stop()
|
|
}
|
|
ar.logger.Debug("AutoReload: shutting down")
|
|
}()
|
|
|
|
ar.logger.Debug("run with config", zap.String("struct", fmt.Sprintf("%#v", ar)))
|
|
ar.logger.Info("auto_reload: started")
|
|
|
|
defer ar.cancel()
|
|
ar.checkCertificates()
|
|
|
|
ar.logger.Debug("after first pass: run with config", zap.String("struct", fmt.Sprintf("%#v", ar)))
|
|
|
|
for {
|
|
select {
|
|
case <-ar.ctx.Done():
|
|
//TODO: shutdown
|
|
return
|
|
case <-ar.ticker.C:
|
|
if ar.checkCertificates() {
|
|
if err := ar.reload(); err != nil {
|
|
ar.logger.Error("auto_reload: reload error", zap.Error(err))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (ar *AutoReload) checkCertificates() bool {
|
|
wg := new(sync.WaitGroup)
|
|
flag := new(atomic.Bool)
|
|
flag.Store(false)
|
|
|
|
func() {
|
|
ar.mtx.RLock()
|
|
defer ar.mtx.RUnlock()
|
|
|
|
ar.logger.Info("auto_reload: checking certificates")
|
|
for path, cert := range ar.certs {
|
|
ar.logger.Debug("checkCertificates", zap.String("cert file", path))
|
|
wg.Go(func() {
|
|
ar.checkCert(path, cert, flag)
|
|
})
|
|
}
|
|
}()
|
|
|
|
wg.Wait()
|
|
|
|
return flag.Load()
|
|
}
|
|
|
|
func (ar *AutoReload) checkCert(path string, cert *x509.Certificate, needReloadFlag *atomic.Bool) {
|
|
if ar.ctx.Err() != nil {
|
|
return
|
|
}
|
|
|
|
b, err := os.ReadFile(path)
|
|
if err != nil {
|
|
ar.logger.Error("auto_reload: failed to read cert file", zap.String("cert file", path), zap.Error(err))
|
|
return
|
|
}
|
|
|
|
p, rest := pem.Decode(b)
|
|
if p == nil {
|
|
ar.logger.Error("auto_reload: failed to decode cert into pem")
|
|
return
|
|
}
|
|
_ = rest
|
|
|
|
var (
|
|
certFile *x509.Certificate
|
|
)
|
|
if certFile, err = x509.ParseCertificate(p.Bytes); err != nil {
|
|
ar.logger.Error("auto_reload: failed to parse x509 certificate", zap.String("cert file", path), zap.Error(err))
|
|
return
|
|
}
|
|
|
|
if cert == nil {
|
|
ar.mtx.Lock()
|
|
defer ar.mtx.Unlock()
|
|
|
|
ar.certs[path] = certFile
|
|
return
|
|
}
|
|
|
|
if cert.Subject.CommonName != certFile.Subject.CommonName {
|
|
ar.logger.Warn("auto_reload: mismatch common name", zap.String("in-memory", cert.Subject.CommonName), zap.String("cert file", certFile.Subject.CommonName))
|
|
}
|
|
|
|
n := len(cert.DNSNames)
|
|
if len(cert.DNSNames) != len(certFile.DNSNames) {
|
|
ar.logger.Warn("auto_reload: mismatch number of dns aliases")
|
|
|
|
if len(certFile.DNSNames) < len(cert.DNSNames) {
|
|
n = len(certFile.DNSNames)
|
|
}
|
|
}
|
|
|
|
certDNS := make([]string, len(cert.DNSNames))
|
|
copy(certDNS, cert.DNSNames)
|
|
slices.Sort(certDNS)
|
|
|
|
certFileDNS := make([]string, len(certFile.DNSNames))
|
|
copy(certFileDNS, certFile.DNSNames)
|
|
slices.Sort(certFileDNS)
|
|
|
|
for i := 0; i < n; i++ {
|
|
if certDNS[i] != certFileDNS[i] {
|
|
ar.logger.Warn("auto_reload: mismatch dns entry", zap.String("in-memory", certDNS[i]), zap.String("cert file", certFileDNS[i]))
|
|
}
|
|
}
|
|
|
|
if !certFile.NotAfter.Equal(cert.NotAfter) {
|
|
needReloadFlag.Store(true)
|
|
|
|
ar.mtx.Lock()
|
|
defer ar.mtx.Unlock()
|
|
|
|
ar.certs[path] = certFile
|
|
}
|
|
}
|
|
|
|
func (ar *AutoReload) reload() error {
|
|
adapter := caddyconfig.GetAdapter("caddyfile")
|
|
if adapter == nil {
|
|
return fmt.Errorf("cannot get caddyfile adapter")
|
|
}
|
|
|
|
var (
|
|
b []byte
|
|
warns []caddyconfig.Warning
|
|
err error
|
|
)
|
|
if b, err = os.ReadFile(ar.caddyfile); err != nil {
|
|
return fmt.Errorf("cannot read %s: %w", ar.caddyfile, err)
|
|
}
|
|
|
|
b, warns, err = adapter.Adapt(b, map[string]any{
|
|
"filename": ar.caddyfile,
|
|
})
|
|
for _, w := range warns {
|
|
ar.logger.Warn("auto_reload: adapt warning", zap.String("warning", w.String()))
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("adapt caddyfile failed: %w", err)
|
|
}
|
|
|
|
if err = caddy.Load(b, true); err != nil {
|
|
return fmt.Errorf("reload config failed: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (AutoReloadModule) CaddyModule() caddy.ModuleInfo {
|
|
return caddy.ModuleInfo{
|
|
ID: "app.auto_reload",
|
|
New: func() caddy.Module {
|
|
return &AutoReloadModule{
|
|
app: GetAutoReload(),
|
|
}
|
|
},
|
|
}
|
|
}
|
|
|
|
func (arm *AutoReloadModule) Provision(ctx caddy.Context) error {
|
|
arm.logger = ctx.Logger(arm)
|
|
|
|
// useful for debugging
|
|
// arm.logger = arm.logger.WithOptions(zap.AddCaller(), zap.AddCallerSkip(1))
|
|
arm.logger = arm.logger.WithOptions(zap.AddCaller())
|
|
|
|
arm.logger.Debug("AutoReloadModule: entering Provision()")
|
|
if arm.Caddyfile == "" {
|
|
return fmt.Errorf("auto_reload: caddyfile is required")
|
|
}
|
|
|
|
if arm.app == nil {
|
|
arm.app = GetAutoReload()
|
|
}
|
|
|
|
arm.app.SetParam(arm.Caddyfile, arm.Interval, arm.logger)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (arm *AutoReloadModule) Start() error {
|
|
arm.logger.Debug("AutoReloadModule: Start()")
|
|
go arm.app.run()
|
|
return nil
|
|
}
|
|
|
|
func (arm *AutoReloadModule) Stop() error {
|
|
arm.logger.Debug("AutoReloadModule: Stop()")
|
|
arm.app.cancel()
|
|
return nil
|
|
}
|
|
|
|
func parseAutoReloadModuleCaddyfile(d *caddyfile.Dispenser, existing any) (any, error) {
|
|
if existing != nil {
|
|
return nil, fmt.Errorf("auto_reload must be defined once")
|
|
}
|
|
|
|
arm := new(AutoReloadModule) // no need to inject AutoReload here as we're going to marshal it into json
|
|
|
|
var (
|
|
dur time.Duration
|
|
err error
|
|
b []byte
|
|
)
|
|
|
|
for d.Next() {
|
|
for d.NextBlock(0) {
|
|
switch d.Val() {
|
|
case "interval":
|
|
if !d.NextArg() {
|
|
return nil, d.ArgErr()
|
|
}
|
|
|
|
if dur, err = caddy.ParseDuration(d.Val()); err != nil {
|
|
return nil, d.Errf("invalid interval %q: %v", d.Val(), err)
|
|
}
|
|
|
|
arm.Interval = caddy.Duration(dur)
|
|
case "caddyfile":
|
|
if !d.NextArg() {
|
|
return nil, d.ArgErr()
|
|
}
|
|
|
|
arm.Caddyfile = d.Val()
|
|
default:
|
|
return nil, d.Errf("unrecognized subdirective %q", d.Val())
|
|
}
|
|
}
|
|
}
|
|
|
|
if b, err = json.Marshal(arm); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return httpcaddyfile.App{
|
|
Name: "app.auto_reload",
|
|
Value: b,
|
|
}, nil
|
|
}
|