unitel/unitelhttp/fiber_middleware.go

301 lines
8.2 KiB
Go

package unitelhttp
import (
"fmt"
"net/http"
"slices"
"time"
"git.devminer.xyz/devminer/unitel"
"git.devminer.xyz/devminer/unitel/unitelutils"
"github.com/getsentry/sentry-go"
"github.com/go-logr/logr"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/utils"
"github.com/valyala/fasthttp/fasthttpadaptor"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/metric"
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
"go.opentelemetry.io/otel/trace"
)
// fiberMwClientID identifies this middleware. FiberMiddleware uses it as the
// instrumentation name of the meter and tracer it creates, and as the SDK
// identifier set on each request's cloned Sentry client. It is built from
// unitelutils.Base and unitelutils.Version (presumably the module path and
// release version — confirm in unitelutils).
const fiberMwClientID = unitelutils.Base + "#FiberMiddleware@" + unitelutils.Version
// FiberMiddlewareConfig configures FiberMiddleware. FiberMiddleware replaces
// zero-valued fields of a supplied config with defaults where noted below.
type FiberMiddlewareConfig struct {
// Repanic re-raises a handler panic after it has been logged and reported
// through the request span. When false the panic is swallowed.
Repanic bool
// WaitForDelivery passes Timeout to the span's panic reporting (span.Recover);
// when false no timeout is passed. Presumably this makes reporting block until
// the event is delivered — confirm in unitel.
WaitForDelivery bool
// Timeout is the delivery timeout used when WaitForDelivery is set.
// A zero value is replaced with 2 seconds.
Timeout time.Duration
// TraceRequestHeaders lists request headers whose non-empty values are
// recorded as span attributes named "http.request.header.<name>".
TraceRequestHeaders []string
// TraceResponseHeaders lists response headers whose non-empty values are
// recorded as span attributes named "http.response.header.<name>".
TraceResponseHeaders []string
// IgnoredRoutes lists request paths (compared exactly against the request
// path) for which no span or metrics are recorded.
IgnoredRoutes []string
// Logger receives instrument-creation errors and recovered panics.
Logger logr.Logger
// TracePropagator decides per request whether to continue the trace from the
// incoming headers and inject trace headers into the response. A nil value
// is replaced with PropagateNoTraces.
TracePropagator TracePropagator
}
// fiberMiddlewareConfigDefault is the configuration FiberMiddleware uses when
// the caller passes none. Repanic and WaitForDelivery keep their zero value
// (false); the delivery timeout is two seconds; no headers are traced and no
// routes are ignored; logging is discarded and traces are not propagated.
var fiberMiddlewareConfigDefault = FiberMiddlewareConfig{
	Timeout:              2 * time.Second,
	TraceRequestHeaders:  []string{},
	TraceResponseHeaders: []string{},
	IgnoredRoutes:        []string{},
	Logger:               logr.Discard(),
	TracePropagator:      PropagateNoTraces,
}
// FiberMiddleware returns a Fiber handler that instruments each request.
// Every request gets the Telemetry on its user context. Non-ignored requests
// additionally get:
//   - a server span started through unitel (optionally continuing the trace
//     from the incoming headers);
//   - a cloned Sentry hub carrying the converted request and its body;
//   - HTTP server metrics (duration, in-flight requests, request and response
//     body sizes).
//
// Only the first element of config is used. When config is empty,
// fiberMiddlewareConfigDefault applies. Zero-valued Timeout, header/route
// lists and TracePropagator in a supplied config are replaced with defaults.
//
// FiberMiddleware panics if any metric instrument cannot be created.
func FiberMiddleware(t *unitel.Telemetry, config ...FiberMiddlewareConfig) fiber.Handler {
	cfg := fiberMiddlewareConfigDefault
	if len(config) > 0 {
		cfg = config[0]
	}
	if cfg.Timeout == 0 {
		cfg.Timeout = time.Second * 2
	}
	if cfg.TraceRequestHeaders == nil {
		cfg.TraceRequestHeaders = []string{}
	}
	if cfg.TraceResponseHeaders == nil {
		cfg.TraceResponseHeaders = []string{}
	}
	if cfg.IgnoredRoutes == nil {
		cfg.IgnoredRoutes = []string{}
	}
	if cfg.TracePropagator == nil {
		cfg.TracePropagator = PropagateNoTraces
	}

	meter := t.MeterProvider.Meter(fiberMwClientID, metric.WithInstrumentationVersion(unitelutils.Version))
	tracer := t.TracerProvider.Tracer(fiberMwClientID, trace.WithInstrumentationVersion(unitelutils.Version))

	mDuration, err := meter.Float64Histogram(
		"http.server.duration",
		metric.WithDescription("HTTP request response times"),
		metric.WithUnit("ms"),
	)
	if err != nil {
		cfg.Logger.Error(err, "Failed to create request duration histogram")
		panic(err)
	}
	mActiveRequests, err := meter.Int64UpDownCounter(
		"http.server.active_requests",
		metric.WithDescription("Number of in-flight HTTP requests"),
		metric.WithUnit("1"),
	)
	if err != nil {
		cfg.Logger.Error(err, "Failed to create active requests counter")
		panic(err)
	}
	mRequestSize, err := meter.Int64Histogram(
		"http.server.request.size",
		metric.WithUnit("By"),
		metric.WithDescription("HTTP request body sizes"),
	)
	if err != nil {
		cfg.Logger.Error(err, "Failed to create request body size counter")
		panic(err)
	}
	mResponseSize, err := meter.Int64Histogram(
		"http.server.response.size",
		metric.WithUnit("By"),
		metric.WithDescription("HTTP response body sizes"),
	)
	if err != nil {
		cfg.Logger.Error(err, "Failed to create response body size counter")
		panic(err)
	}

	return func(c *fiber.Ctx) error {
		ctx := unitel.SetOnContext(t, c.UserContext())
		c.SetUserContext(ctx)

		// Skip ignored routes (/ping for example): they keep the telemetry on
		// their context but are neither traced nor measured.
		if slices.Contains(cfg.IgnoredRoutes, c.Path()) {
			return c.Next()
		}

		start := time.Now()

		// Metric attributes are limited to bounded values. Per-request values
		// such as the client address, full URL, user agent or Content-Length
		// would create a new time series for nearly every request.
		requestMetricAttrs := []attribute.KeyValue{
			httpFlavorAttribute(c),
			semconv.HTTPRequestMethodKey.String(utils.CopyString(c.Method())),
			semconv.URLScheme(utils.CopyString(c.Protocol())),
		}
		activeRequestsAttrs := metric.WithAttributes(requestMetricAttrs...)
		mActiveRequests.Add(ctx, 1, activeRequestsAttrs)
		// Deferred immediately so the in-flight count is decremented on every
		// exit path, including request conversion failures and handler panics.
		defer mActiveRequests.Add(ctx, -1, activeRequestsAttrs)

		var stdRequest http.Request
		if err := fasthttpadaptor.ConvertRequest(c.Context(), &stdRequest, true); err != nil {
			return err
		}

		opts := []unitel.SpanStartOpt{
			unitel.WithOtelOptions(trace.WithSpanKind(trace.SpanKindServer)),
			unitel.WithOtelTracer(tracer),
		}
		shouldPropagate := cfg.TracePropagator(&stdRequest)
		if shouldPropagate {
			opts = append(opts, t.ContinueFromHeaders(stdRequest.Header))
		}

		description := fmt.Sprintf("%s %s", c.Method(), c.Path())
		span := t.StartSpan(ctx, "http.server", description, opts...)

		// Deferred calls run last-in-first-out. Trace headers are injected into
		// the response first, then a panic (if any) is reported, and the span
		// is ended last so the panic is recorded while the span is still open.
		defer span.End()
		defer func() {
			r := recover()
			if r == nil {
				return
			}
			cfg.Logger.WithCallDepth(0).V(-2).Info("panic()ed: ", "err", r)
			var timeout *time.Duration
			if cfg.WaitForDelivery {
				timeout = &cfg.Timeout
			}
			span.Recover(ctx, fmt.Errorf("%v", r), timeout)
			if cfg.Repanic {
				panic(r)
			}
		}()
		defer func() {
			if !shouldPropagate {
				return
			}
			h := http.Header{}
			t.InjectIntoHeaders(ctx, h)
			for k, v := range h {
				for _, vv := range v {
					c.Response().Header.Add(k, vv)
				}
			}
		}()

		span.AddAttributes(httpServerTraceAttributesFromRequest(c)...)
		for _, k := range cfg.TraceRequestHeaders {
			if h := c.Get(k); h != "" {
				span.AddAttributes(attribute.String(fmt.Sprintf("http.request.header.%s", k), h))
			}
		}

		ctx = span.Context()
		hub := sentry.CurrentHub().Clone()
		if client := hub.Client(); client != nil {
			client.SetSDKIdentifier(fiberMwClientID)
		}
		scope := hub.Scope()
		scope.SetRequest(&stdRequest)
		scope.SetRequestBody(utils.CopyBytes(c.Body()))
		ctx = sentry.SetHubOnContext(ctx, hub)
		c.SetUserContext(ctx)

		err := c.Next()
		if err != nil {
			// *fiber.Error values below 500 are expected client errors; everything
			// else (including non-Fiber errors) is reported on the span.
			if fe, ok := err.(*fiber.Error); !ok || fe.Code >= http.StatusInternalServerError {
				span.CaptureError(err)
			}
			err = c.App().Config().ErrorHandler(c, err)
		}

		statusCode := c.Response().StatusCode()
		route := c.Route().Path
		responseAttrs := []attribute.KeyValue{
			semconv.HTTPResponseStatusCode(statusCode),
			semconv.HTTPRouteKey.String(route),
		}
		requestSize := int64(len(c.Request().Body()))
		responseSize := int64(len(c.Response().Body()))

		metricAttrs := make([]attribute.KeyValue, 0, len(requestMetricAttrs)+len(responseAttrs))
		metricAttrs = append(metricAttrs, requestMetricAttrs...)
		metricAttrs = append(metricAttrs, responseAttrs...)
		recordAttrs := metric.WithAttributes(metricAttrs...)

		// Fractional milliseconds: Duration.Milliseconds truncates, which would
		// record every sub-millisecond request as 0.
		elapsedMs := float64(time.Since(start)) / float64(time.Millisecond)
		mDuration.Record(ctx, elapsedMs, recordAttrs)
		mRequestSize.Record(ctx, requestSize, recordAttrs)
		mResponseSize.Record(ctx, responseSize, recordAttrs)

		span.
			AddAttributes(responseAttrs...).
			AddAttributes(attribute.Int64("http.request.headers.content-length", requestSize)).
			SetName(route).
			SetStatus(httpStatusToSpanStatus(statusCode, true), "")
		for _, k := range cfg.TraceResponseHeaders {
			if h := c.GetRespHeader(k); h != "" {
				span.AddAttributes(attribute.String(fmt.Sprintf("http.response.header.%s", k), h))
			}
		}

		return err
	}
}
// httpStatusToSpanStatus maps an HTTP response status code to a unitel span
// status following the OpenTelemetry HTTP semantic conventions:
//   - 1xx-3xx: the OpenTelemetry status is left Unset. Instrumentation must
//     not mark successful responses as Ok, and leaving it Unset keeps an Error
//     recorded earlier on the span.
//   - 4xx: Unset for server spans (isServer), Error for client spans.
//   - 5xx, and codes outside the valid 100-599 range: Error.
//
// The Sentry status is always derived from code with sentry.HTTPtoSpanStatus.
func httpStatusToSpanStatus(code int, isServer bool) unitel.SpanStatus {
	sentryStatus := sentry.HTTPtoSpanStatus(code)
	switch {
	case code < 100 || code > 599:
		// Not an HTTP status code the caller could interpret.
		return unitel.SpanStatus{codes.Error, sentryStatus}
	case code < http.StatusBadRequest:
		return unitel.SpanStatus{codes.Unset, sentryStatus}
	case code < http.StatusInternalServerError && isServer:
		// For HTTP status codes in the 4xx range span status MUST be left unset
		// in case of SpanKind.SERVER and MUST be set to Error in case of SpanKind.CLIENT.
		return unitel.SpanStatus{codes.Unset, sentryStatus}
	default:
		return unitel.SpanStatus{codes.Error, sentryStatus}
	}
}
// httpFlavorAttribute reports the request's HTTP version as the semconv
// network.protocol.version attribute: "1.1" when the request header is
// HTTP/1.1, "1.0" otherwise. fasthttp is assumed to serve only HTTP/1.x.
//
// network.protocol.name is intentionally not emitted. The semantic conventions
// only require it when the protocol is not "http", and its value must be the
// lowercase protocol name, not a combined "HTTP/1.1" string.
func httpFlavorAttribute(c *fiber.Ctx) attribute.KeyValue {
	if c.Request().Header.IsHTTP11() {
		return semconv.NetworkProtocolVersion("1.1")
	}
	return semconv.NetworkProtocolVersion("1.0")
}
// httpServerTraceAttributesFromRequest builds the request-derived attributes of
// a server span:
//   - protocol version, method and URL scheme;
//   - URL path, query (when present) and absolute URL;
//   - server address, user agent and TCP transport;
//   - request Content-Length (when declared) and client address (when known).
//
// Values backed by fasthttp's reusable request buffers are copied, either
// explicitly with utils.CopyString or implicitly via string conversion or
// concatenation. This keeps the attributes valid after the handler returns.
func httpServerTraceAttributesFromRequest(c *fiber.Ctx) []attribute.KeyValue {
	req := c.Request()
	attrs := []attribute.KeyValue{
		httpFlavorAttribute(c),
		semconv.HTTPRequestMethodKey.String(utils.CopyString(c.Method())),
		semconv.URLScheme(utils.CopyString(c.Protocol())),
		// url.path is the path component only; the query string goes into url.query.
		semconv.URLPath(utils.CopyString(c.Path())),
		// url.full must be absolute. OriginalURL holds only the path and query,
		// so prefix it with scheme://host. Concatenation allocates a new string.
		semconv.URLFull(c.BaseURL() + c.OriginalURL()),
		semconv.ServerAddress(utils.CopyString(c.Hostname())),
		semconv.UserAgentOriginalKey.String(string(req.Header.UserAgent())),
		semconv.NetworkTransportTCP,
	}
	// This is the *request* Content-Length. fasthttp reports negative values
	// (-1 chunked, -2 identity) when no length was declared, so those are skipped.
	if n := req.Header.ContentLength(); n >= 0 {
		attrs = append(attrs, attribute.Int("http.request.header.content-length", n))
	}
	if q := req.URI().QueryString(); len(q) > 0 {
		attrs = append(attrs, semconv.URLQuery(string(q)))
	}
	if clientIP := c.IP(); clientIP != "" {
		attrs = append(attrs, semconv.ClientAddressKey.String(utils.CopyString(clientIP)))
	}
	return attrs
}