package unitelhttp import ( "fmt" "net/http" "runtime/debug" "slices" "time" "git.devminer.xyz/devminer/unitel" "git.devminer.xyz/devminer/unitel/unitelutils" "github.com/getsentry/sentry-go" "github.com/go-logr/logr" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/utils" "github.com/valyala/fasthttp/fasthttpadaptor" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/metric" semconv "go.opentelemetry.io/otel/semconv/v1.26.0" "go.opentelemetry.io/otel/trace" ) const fiberMwClientID = unitelutils.Base + "#FiberMiddleware@" + unitelutils.Version type FiberMiddlewareConfig struct { Repanic bool WaitForDelivery bool Timeout time.Duration TraceRequestHeaders []string TraceResponseHeaders []string IgnoredRoutes []string Logger logr.Logger TracePropagator TracePropagator } var fiberMiddlewareConfigDefault = FiberMiddlewareConfig{ Repanic: false, WaitForDelivery: false, Timeout: time.Second * 2, TraceRequestHeaders: []string{}, TraceResponseHeaders: []string{}, IgnoredRoutes: []string{}, Logger: logr.Discard(), TracePropagator: PropagateNoTraces, } func FiberMiddleware(t *unitel.Telemetry, config ...FiberMiddlewareConfig) fiber.Handler { cfg := fiberMiddlewareConfigDefault if len(config) > 0 { cfg = config[0] } if cfg.Timeout == 0 { cfg.Timeout = time.Second * 2 } if cfg.TraceRequestHeaders == nil { cfg.TraceRequestHeaders = []string{} } if cfg.TraceResponseHeaders == nil { cfg.TraceResponseHeaders = []string{} } if cfg.IgnoredRoutes == nil { cfg.IgnoredRoutes = []string{} } if cfg.TracePropagator == nil { cfg.TracePropagator = PropagateNoTraces } meter := t.MeterProvider.Meter(fiberMwClientID, metric.WithInstrumentationVersion(unitelutils.Version)) tracer := t.TracerProvider.Tracer(fiberMwClientID, trace.WithInstrumentationVersion(unitelutils.Version)) mDuration, err := meter.Float64Histogram( "http.server.duration", metric.WithDescription("HTTP request response times"), metric.WithUnit("ms"), ) if err != nil { cfg.Logger.Error(err, "Failed to create request duration histogram") panic(err) } mActiveRequests, err := meter.Int64UpDownCounter( "http.server.active_requests", metric.WithDescription("Number of in-flight HTTP requests"), metric.WithUnit("1"), ) if err != nil { cfg.Logger.Error(err, "Failed to create active requests counter") panic(err) } mRequestSize, err := meter.Int64Histogram( "http.server.request.size", metric.WithUnit("By"), metric.WithDescription("HTTP request body sizes"), ) if err != nil { cfg.Logger.Error(err, "Failed to create request body size counter") panic(err) } mResponseSize, err := meter.Int64Histogram( "http.server.response.size", metric.WithUnit("By"), metric.WithDescription("HTTP response body sizes"), ) if err != nil { cfg.Logger.Error(err, "Failed to create response body size counter") panic(err) } return func(c *fiber.Ctx) error { ctx := c.UserContext() ctx = unitel.SetOnContext(t, ctx) c.SetUserContext(ctx) // Skip ignored routes (/ping for example) if slices.Contains(cfg.IgnoredRoutes, c.Path()) { return c.Next() } start := time.Now() requestMetricsAttrs := httpServerTraceAttributesFromRequest(c) mActiveRequests.Add(ctx, 1, metric.WithAttributes(requestMetricsAttrs...)) responseMetricAttrs := make([]attribute.KeyValue, len(requestMetricsAttrs)) copy(responseMetricAttrs, requestMetricsAttrs) var stdRequest http.Request if err := fasthttpadaptor.ConvertRequest(c.Context(), &stdRequest, true); err != nil { return err } opts := []unitel.SpanStartOpt{ unitel.WithOtelOptions(trace.WithSpanKind(trace.SpanKindServer)), unitel.WithOtelTracer(tracer), } shouldPropagate := cfg.TracePropagator(&stdRequest) if shouldPropagate { opts = append(opts, t.ContinueFromHeaders(stdRequest.Header)) } description := fmt.Sprintf("%s %s", c.Method(), c.Path()) span := t.StartSpan( ctx, "http.server", description, opts..., ) defer func() { if err := recover(); err != nil { marker, l := cfg.Logger.WithCallStackHelper() marker() l.V(-2).Info("recover()ed from panic() in handler", "err", err) l.V(-2).Info("stack", "stack", string(debug.Stack())) timeout := (*time.Duration)(nil) if cfg.WaitForDelivery { timeout = &cfg.Timeout } span.Recover(ctx, fmt.Errorf("%v", err), timeout) if cfg.Repanic { panic(err) } } }() defer span.End() defer func() { if !shouldPropagate { return } h := http.Header{} t.InjectIntoHeaders(ctx, h) for k, v := range h { for _, vv := range v { c.Response().Header.Add(k, vv) } } }() span.AddAttributes(httpServerTraceAttributesFromRequest(c)...) for _, k := range cfg.TraceRequestHeaders { if h := c.Get(k); h != "" { span.AddAttributes(attribute.String(fmt.Sprintf("http.request.header.%s", k), h)) } } ctx = span.Context() hub := sentry.CurrentHub().Clone() if client := hub.Client(); client != nil { client.SetSDKIdentifier(fiberMwClientID) } scope := hub.Scope() scope.SetRequest(&stdRequest) scope.SetRequestBody(utils.CopyBytes(c.Body())) ctx = sentry.SetHubOnContext(ctx, hub) c.SetUserContext(ctx) var err error = nil if err = c.Next(); err != nil { shouldReport := false switch err := err.(type) { case *fiber.Error: shouldReport = err.Code >= http.StatusInternalServerError default: shouldReport = true } if shouldReport { span.CaptureError(err) } err = c.App().Config().ErrorHandler(c, err) } defer func() { responseAttrs := []attribute.KeyValue{ semconv.HTTPResponseStatusCode(c.Response().StatusCode()), semconv.HTTPRouteKey.String(c.Route().Path), } requestSize := int64(len(c.Request().Body())) responseSize := int64(len(c.Response().Body())) responseMetricAttrs = append(responseMetricAttrs, responseAttrs...) mActiveRequests.Add(c.Context(), -1, metric.WithAttributes(requestMetricsAttrs...)) mDuration.Record(ctx, float64(time.Since(start).Milliseconds()), metric.WithAttributes(responseMetricAttrs...)) mRequestSize.Record(ctx, requestSize, metric.WithAttributes(responseMetricAttrs...)) mResponseSize.Record(ctx, responseSize, metric.WithAttributes(responseMetricAttrs...)) span. AddAttributes(responseAttrs...). AddAttributes(attribute.Int64("http.request.headers.content-length", requestSize)). SetName(c.Route().Path). SetStatus(httpStatusToSpanStatus(c.Response().StatusCode(), true), "") for _, k := range cfg.TraceResponseHeaders { if h := c.GetRespHeader(k); h != "" { span.AddAttributes(attribute.String(fmt.Sprintf("http.response.header.%s", k), h)) } } }() return err } } func httpStatusToSpanStatus(code int, isServer bool) unitel.SpanStatus { sentryStatus := sentry.HTTPtoSpanStatus(code) if code < http.StatusBadRequest { return unitel.SpanStatus{codes.Ok, sentryStatus} } if code < http.StatusInternalServerError { // For HTTP status codes in the 4xx range span status MUST be left unset // in case of SpanKind.SERVER and MUST be set to Error in case of SpanKind.CLIENT. if isServer { return unitel.SpanStatus{codes.Unset, sentryStatus} } return unitel.SpanStatus{codes.Error, sentryStatus} } return unitel.SpanStatus{codes.Error, sentryStatus} } func httpFlavorAttribute(c *fiber.Ctx) attribute.KeyValue { if c.Request().Header.IsHTTP11() { return semconv.NetworkProtocolName("HTTP/1.1") } return semconv.NetworkProtocolName("HTTP/1.0") } func httpServerTraceAttributesFromRequest(c *fiber.Ctx) []attribute.KeyValue { attrs := []attribute.KeyValue{ httpFlavorAttribute(c), semconv.HTTPRequestMethodKey.String(utils.CopyString(c.Method())), attribute.Int("http.response.header.content-length", c.Request().Header.ContentLength()), semconv.URLScheme(utils.CopyString(c.Protocol())), semconv.URLPath(utils.CopyString(string(c.Request().RequestURI()))), semconv.URLFull(utils.CopyString(c.OriginalURL())), semconv.ServerAddress(utils.CopyString(c.Hostname())), semconv.UserAgentOriginalKey.String(utils.CopyString(string(c.Request().Header.UserAgent()))), semconv.NetworkTransportTCP, } clientIP := c.IP() if len(clientIP) > 0 { attrs = append(attrs, semconv.ClientAddressKey.String(utils.CopyString(clientIP))) } return attrs }