diff --git a/listeners.go b/listeners.go index 14a5c499..4c851bd1 100644 --- a/listeners.go +++ b/listeners.go @@ -449,7 +449,11 @@ func ListenQUIC(ln net.PacketConn, tlsConf *tls.Config, activeRequests *int64) ( lnKey := listenerKey("quic+"+ln.LocalAddr().Network(), ln.LocalAddr().String()) sharedEarlyListener, _, err := listenerPool.LoadOrNew(lnKey, func() (Destructor, error) { - earlyLn, err := quic.ListenEarly(ln, http3.ConfigureTLSConfig(tlsConf), &quic.Config{ + sqtc := newSharedQUICTLSConfig(tlsConf) + // http3.ConfigureTLSConfig only uses this field and tls App sets this field as well + //nolint:gosec + quicTlsConfig := &tls.Config{GetConfigForClient: sqtc.getConfigForClient} + earlyLn, err := quic.ListenEarly(ln, http3.ConfigureTLSConfig(quicTlsConfig), &quic.Config{ Allow0RTT: func(net.Addr) bool { return true }, RequireAddressValidation: func(clientAddr net.Addr) bool { var highLoad bool @@ -462,12 +466,16 @@ func ListenQUIC(ln net.PacketConn, tlsConf *tls.Config, activeRequests *int64) ( if err != nil { return nil, err } - return &sharedQuicListener{EarlyListener: earlyLn, key: lnKey}, nil + return &sharedQuicListener{EarlyListener: earlyLn, sqtc: sqtc, key: lnKey}, nil }) if err != nil { return nil, err } + sql := sharedEarlyListener.(*sharedQuicListener) + // add current tls.Config to sqtc, so GetConfigForClient will always return the latest tls.Config in case of context cancellation + ctx, cancel := sql.sqtc.addTLSConfig(tlsConf) + // TODO: to serve QUIC over a unix socket, currently we need to hold onto // the underlying net.PacketConn (which we wrap as unixConn to keep count // of closes) because closing the quic.EarlyListener doesn't actually close @@ -479,9 +487,8 @@ func ListenQUIC(ln net.PacketConn, tlsConf *tls.Config, activeRequests *int64) ( unix = uc } - ctx, cancel := context.WithCancel(context.Background()) return &fakeCloseQuicListener{ - sharedQuicListener: sharedEarlyListener.(*sharedQuicListener), + sharedQuicListener: sql, uc: unix, context: ctx, contextCancel: cancel, @@ -494,10 +501,77 @@ func ListenerUsage(network, addr string) int { return count } +// contextAndCancelFunc groups context and its cancelFunc +type contextAndCancelFunc struct { + context.Context + context.CancelFunc +} + +// sharedQUICTLSConfig manages GetConfigForClient +// see issue: https://github.com/caddyserver/caddy/pull/4849 +type sharedQUICTLSConfig struct { + rmu sync.RWMutex + tlsConfs map[*tls.Config]contextAndCancelFunc + activeTlsConf *tls.Config +} + +// newSharedQUICTLSConfig creates a new sharedQUICTLSConfig +func newSharedQUICTLSConfig(tlsConfig *tls.Config) *sharedQUICTLSConfig { + sqtc := &sharedQUICTLSConfig{ + tlsConfs: make(map[*tls.Config]contextAndCancelFunc), + activeTlsConf: tlsConfig, + } + sqtc.addTLSConfig(tlsConfig) + return sqtc +} + +// getConfigForClient is used as tls.Config's GetConfigForClient field +func (sqtc *sharedQUICTLSConfig) getConfigForClient(ch *tls.ClientHelloInfo) (*tls.Config, error) { + sqtc.rmu.RLock() + defer sqtc.rmu.RUnlock() + return sqtc.activeTlsConf.GetConfigForClient(ch) +} + +// addTLSConfig adds tls.Config to the map if not present and returns the corresponding context and its cancelFunc +// so that when cancelled, the active tls.Config will change +func (sqtc *sharedQUICTLSConfig) addTLSConfig(tlsConfig *tls.Config) (context.Context, context.CancelFunc) { + sqtc.rmu.Lock() + defer sqtc.rmu.Unlock() + + if cacc, ok := sqtc.tlsConfs[tlsConfig]; ok { + return cacc.Context, cacc.CancelFunc + } + + ctx, cancel := context.WithCancel(context.Background()) + wrappedCancel := func() { + cancel() + + sqtc.rmu.Lock() + defer sqtc.rmu.Unlock() + + delete(sqtc.tlsConfs, tlsConfig) + if sqtc.activeTlsConf == tlsConfig { + // select another tls.Config, if there is none, + // related sharedQuicListener will be destroyed anyway + for tc := range sqtc.tlsConfs { + sqtc.activeTlsConf = tc + break + } + } + } + sqtc.tlsConfs[tlsConfig] = contextAndCancelFunc{ctx, wrappedCancel} + // there should be at most 2 tls.Configs + if len(sqtc.tlsConfs) > 2 { + Log().Warn("quic listener tls configs are more than 2", zap.Int("number of configs", len(sqtc.tlsConfs))) + } + return ctx, wrappedCancel +} + // sharedQuicListener is like sharedListener, but for quic.EarlyListeners. type sharedQuicListener struct { quic.EarlyListener - key string + sqtc *sharedQUICTLSConfig + key string } // Destruct closes the underlying QUIC listener.