diff --git a/helper.go b/helper.go index 19e2399..17dd7c1 100644 --- a/helper.go +++ b/helper.go @@ -32,7 +32,7 @@ func AddBreadcrumbToContext(c context.Context, breadcrumb *sentry.Breadcrumb) { } } -func (o *Telemetry) Trace(ctx context.Context, op string, options []ConfigureSpanStartFunc, fn func(context.Context)) { +func (o *Telemetry) Trace(ctx context.Context, op string, options []SpanStartOpt, fn func(context.Context)) { tx := o.StartSpan(ctx, op, op, options...) defer tx.End() diff --git a/tracing.go b/tracing.go index 27b533f..fdd3d3c 100644 --- a/tracing.go +++ b/tracing.go @@ -23,9 +23,9 @@ var ( tracerContextKey = contextKey{"tracer"} ) -type ConfigureSpanStartFunc = func(context.Context) (context.Context, []trace.SpanStartOption, []sentry.SpanOption) +type SpanStartOpt = func(context.Context) (context.Context, []trace.SpanStartOption, []sentry.SpanOption) -func (t *Telemetry) StartSpan(ctx context.Context, operation, name string, cfgs ...ConfigureSpanStartFunc) *Span { +func (t *Telemetry) StartSpan(ctx context.Context, operation, name string, cfgs ...SpanStartOpt) *Span { otelStartOpts := make([]trace.SpanStartOption, 0) sentryStartOpts := []sentry.SpanOption{sentry.WithTransactionName(name), sentry.WithDescription(name)} @@ -55,7 +55,7 @@ func (t *Telemetry) StartSpan(ctx context.Context, operation, name string, cfgs } } -func WithOtelOptions(opts ...trace.SpanStartOption) ConfigureSpanStartFunc { +func WithOtelOptions(opts ...trace.SpanStartOption) SpanStartOpt { return func(ctx context.Context) (context.Context, []trace.SpanStartOption, []sentry.SpanOption) { return ctx, opts, []sentry.SpanOption{} } @@ -85,7 +85,7 @@ func (t *Telemetry) InjectIntoMap(ctx context.Context, m map[string]string) { } } -func (t *Telemetry) ContinueFromHeaders(h http.Header) ConfigureSpanStartFunc { +func (t *Telemetry) ContinueFromHeaders(h http.Header) SpanStartOpt { return func(ctx context.Context) (context.Context, []trace.SpanStartOption, []sentry.SpanOption) { ctx = t.Propagator.Extract(ctx, propagation.HeaderCarrier(h)) @@ -98,7 +98,7 @@ func (t *Telemetry) ContinueFromHeaders(h http.Header) ConfigureSpanStartFunc { } } -func (t *Telemetry) ContinueFromMap(m map[string]string) ConfigureSpanStartFunc { +func (t *Telemetry) ContinueFromMap(m map[string]string) SpanStartOpt { return func(ctx context.Context) (context.Context, []trace.SpanStartOption, []sentry.SpanOption) { ctx = t.Propagator.Extract(ctx, propagation.MapCarrier(m)) @@ -114,7 +114,7 @@ func (t *Telemetry) ContinueFromMap(m map[string]string) ConfigureSpanStartFunc } } -func WithOtelTracer(tracer trace.Tracer) ConfigureSpanStartFunc { +func WithOtelTracer(tracer trace.Tracer) SpanStartOpt { return func(ctx context.Context) (context.Context, []trace.SpanStartOption, []sentry.SpanOption) { return context.WithValue(ctx, tracerContextKey, tracer), []trace.SpanStartOption{}, []sentry.SpanOption{} } diff --git a/unitelhttp/fiber_middleware.go b/unitelhttp/fiber_middleware.go index bbc7031..d5711e3 100644 --- a/unitelhttp/fiber_middleware.go +++ b/unitelhttp/fiber_middleware.go @@ -16,7 +16,6 @@ import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/metric" - "go.opentelemetry.io/otel/propagation" semconv "go.opentelemetry.io/otel/semconv/v1.26.0" "go.opentelemetry.io/otel/trace" ) @@ -31,6 +30,7 @@ type FiberMiddlewareConfig struct { TraceResponseHeaders []string IgnoredRoutes []string Logger logr.Logger + TracePropagator TracePropagator } var fiberMiddlewareConfigDefault = FiberMiddlewareConfig{ @@ -41,6 +41,7 @@ var fiberMiddlewareConfigDefault = FiberMiddlewareConfig{ TraceResponseHeaders: []string{}, IgnoredRoutes: []string{}, Logger: logr.Discard(), + TracePropagator: PropagateNoTraces, } func FiberMiddleware(t *unitel.Telemetry, config ...FiberMiddlewareConfig) fiber.Handler { @@ -60,6 +61,9 @@ func FiberMiddleware(t *unitel.Telemetry, config ...FiberMiddlewareConfig) fiber if cfg.IgnoredRoutes == nil { cfg.IgnoredRoutes = []string{} } + if cfg.TracePropagator == nil { + cfg.TracePropagator = PropagateNoTraces + } meter := t.MeterProvider.Meter(fiberMwClientID, metric.WithInstrumentationVersion(unitelutils.Version)) tracer := t.TracerProvider.Tracer(fiberMwClientID, trace.WithInstrumentationVersion(unitelutils.Version)) @@ -127,27 +131,21 @@ func FiberMiddleware(t *unitel.Telemetry, config ...FiberMiddlewareConfig) fiber return err } - // FIXME: Only extract the headers from trusted servers - ctx = t.Propagator.Extract(ctx, propagation.HeaderCarrier(stdRequest.Header)) - - hub := sentry.CurrentHub().Clone() - if client := hub.Client(); client != nil { - client.SetSDKIdentifier(fiberMwClientID) + opts := []unitel.SpanStartOpt{ + unitel.WithOtelOptions(trace.WithSpanKind(trace.SpanKindServer)), + unitel.WithOtelTracer(tracer), + } + shouldPropagate := cfg.TracePropagator(&stdRequest) + if shouldPropagate { + opts = append(opts, t.ContinueFromHeaders(stdRequest.Header)) } - - scope := hub.Scope() - scope.SetRequest(&stdRequest) - scope.SetRequestBody(utils.CopyBytes(c.Body())) - ctx = sentry.SetHubOnContext(ctx, hub) description := fmt.Sprintf("%s %s", c.Method(), c.Path()) span := t.StartSpan( ctx, "http.server", description, - unitel.WithOtelOptions(trace.WithSpanKind(trace.SpanKindServer)), - unitel.WithOtelTracer(tracer), - t.ContinueFromHeaders(stdRequest.Header), + opts..., ) defer func() { if err := recover(); err != nil { @@ -168,11 +166,17 @@ func FiberMiddleware(t *unitel.Telemetry, config ...FiberMiddlewareConfig) fiber defer span.End() defer func() { - h := propagation.HeaderCarrier{} - t.Propagator.Inject(ctx, h) + if !shouldPropagate { + return + } - for _, k := range h.Keys() { - c.Set(k, h.Get(k)) + h := http.Header{} + t.InjectIntoHeaders(ctx, h) + + for k, v := range h { + for _, vv := range v { + c.Response().Header.Add(k, vv) + } } }() @@ -183,6 +187,17 @@ func FiberMiddleware(t *unitel.Telemetry, config ...FiberMiddlewareConfig) fiber } } ctx = span.Context() + + hub := sentry.CurrentHub().Clone() + if client := hub.Client(); client != nil { + client.SetSDKIdentifier(fiberMwClientID) + } + + scope := hub.Scope() + scope.SetRequest(&stdRequest) + scope.SetRequestBody(utils.CopyBytes(c.Body())) + ctx = sentry.SetHubOnContext(ctx, hub) + c.SetUserContext(ctx) var err error = nil