diff --git a/unitelhttp/transport.go b/unitelhttp/transport.go index 4a6c291..db260d9 100644 --- a/unitelhttp/transport.go +++ b/unitelhttp/transport.go @@ -24,12 +24,28 @@ func WithLogger(l logr.Logger) HTTPTransportOpt { } } +type TracePropagator = func(r *http.Request) bool + +func WithTracePropagation(propagator TracePropagator) HTTPTransportOpt { + return func(t *HTTPTransport) { + t.tracePropagator = propagator + } +} + +func PropagateAllTraces(req *http.Request) bool { + return true +} + +func PropagateNoTraces(req *http.Request) bool { + return false +} + type HTTPTransport struct { logger logr.Logger telemetry *unitel.Telemetry transport http.RoundTripper - forwardTrace bool + tracePropagator TracePropagator tracedRequestHeaders []string tracedResponseHeaders []string @@ -37,13 +53,13 @@ type HTTPTransport struct { tracer trace.Tracer } -func NewTracedTransport(t *unitel.Telemetry, inner http.RoundTripper, forwardTrace bool, tracedRequestHeaders []string, tracedResponseHeaders []string, opts ...HTTPTransportOpt) *HTTPTransport { +func NewTracedTransport(t *unitel.Telemetry, inner http.RoundTripper, tracedRequestHeaders []string, tracedResponseHeaders []string, opts ...HTTPTransportOpt) *HTTPTransport { transport := &HTTPTransport{ logger: logr.Discard(), telemetry: t, transport: inner, - forwardTrace: forwardTrace, + tracePropagator: PropagateNoTraces, tracedRequestHeaders: tracedRequestHeaders, tracedResponseHeaders: tracedResponseHeaders, @@ -69,7 +85,7 @@ func (t *HTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { ctx := s.Context() req = req.WithContext(ctx) - if t.forwardTrace { + if t.tracePropagator(req) { t.telemetry.InjectIntoHeaders(ctx, req.Header) }