unitel/fiber_middleware.go

281 lines
7.8 KiB
Go
Raw Normal View History

2024-07-23 17:46:15 +02:00
package unitel
import (
	"errors"
	"fmt"
	"net/http"
	"slices"
	"time"

	"github.com/getsentry/sentry-go"
	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/utils"
	"github.com/rs/zerolog/log"
	"github.com/valyala/fasthttp/fasthttpadaptor"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/codes"
	"go.opentelemetry.io/otel/metric"
	"go.opentelemetry.io/otel/propagation"
	semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
	"go.opentelemetry.io/otel/trace"
)
// fiberMwClientID identifies the Fiber middleware: it is the instrumentation
// name of the middleware's tracer and meter, and the SDK identifier set on the
// cloned Sentry client for each traced request.
const fiberMwClientID = libBase + "#FiberMiddleware"
// FiberMiddlewareConfig configures Telemetry.FiberMiddleware.
type FiberMiddlewareConfig struct {
	// Repanic re-raises a recovered panic after it has been reported.
	Repanic bool
	// WaitForDelivery passes Timeout to span.Recover when a panic is reported
	// (presumably so the report can be delivered before continuing).
	WaitForDelivery bool
	// Timeout is the duration handed to span.Recover when WaitForDelivery is
	// set. FiberMiddleware replaces a zero value with two seconds.
	Timeout time.Duration
	// TraceRequestHeaders lists request headers recorded on the span as
	// "http.request.header.<name>" attributes when they are non-empty.
	TraceRequestHeaders []string
	// TraceResponseHeaders lists response headers recorded on the span as
	// "http.response.header.<name>" attributes when they are non-empty.
	TraceResponseHeaders []string
	// IgnoredRoutes lists request paths, compared exactly, that bypass
	// tracing and metrics altogether (e.g. "/ping").
	IgnoredRoutes []string
}

// fiberMiddlewareConfigDefault is used when FiberMiddleware receives no
// config: no re-panicking, no waiting for delivery, a two-second timeout, and
// empty (non-nil) header and route lists.
var fiberMiddlewareConfigDefault = FiberMiddlewareConfig{
	Timeout:              2 * time.Second,
	IgnoredRoutes:        []string{},
	TraceRequestHeaders:  []string{},
	TraceResponseHeaders: []string{},
	Repanic:              false,
	WaitForDelivery:      false,
}
// newFiberMiddlewareTracer obtains the middleware's tracer from tp, named
// after fiberMwClientID and tagged with the library version.
func newFiberMiddlewareTracer(tp trace.TracerProvider) trace.Tracer {
	versionOpt := trace.WithInstrumentationVersion(libVersion)
	return tp.Tracer(fiberMwClientID, versionOpt)
}
// newFiberMiddlewareMeter obtains the middleware's meter from mp, named after
// fiberMwClientID and tagged with the library version.
func newFiberMiddlewareMeter(mp metric.MeterProvider) metric.Meter {
	versionOpt := metric.WithInstrumentationVersion(libVersion)
	return mp.Meter(fiberMwClientID, versionOpt)
}
// FiberMiddleware returns a fiber handler that traces and meters the requests
// passing through it. Only the first config is used; without one,
// fiberMiddlewareConfigDefault applies. A zero Timeout or nil slices in a
// supplied config are replaced with defaults.
//
// For every request whose path is not listed in IgnoredRoutes the handler:
//   - extracts the incoming trace context with t.propagator and starts a
//     server-kind span through t.StartSpan, also passing
//     t.ContinueFromRequest for the converted net/http request;
//   - attaches a cloned Sentry hub, carrying the request and a copy of its
//     body, to the context, and stores the span context as the fiber user
//     context for downstream handlers;
//   - updates the http.server.* instruments: in-flight requests, duration in
//     milliseconds, request and response body sizes;
//   - captures handler errors on the span (except *fiber.Error values with a
//     status below 500) and runs the app's ErrorHandler itself, so the status
//     recorded on the span and metrics is the one written to the response;
//   - injects the current trace context into the response headers.
//
// Panics raised by downstream handlers are recovered and passed to
// span.Recover before the span is ended; they are re-raised when Repanic is
// set. Failing to create any metric instrument is fatal (log.Fatal).
func (t *Telemetry) FiberMiddleware(config ...FiberMiddlewareConfig) fiber.Handler {
	cfg := fiberMiddlewareConfigDefault
	if len(config) > 0 {
		cfg = config[0]
	}
	if cfg.Timeout == 0 {
		cfg.Timeout = time.Second * 2
	}
	if cfg.TraceRequestHeaders == nil {
		cfg.TraceRequestHeaders = []string{}
	}
	if cfg.TraceResponseHeaders == nil {
		cfg.TraceResponseHeaders = []string{}
	}
	if cfg.IgnoredRoutes == nil {
		cfg.IgnoredRoutes = []string{}
	}

	meter := newFiberMiddlewareMeter(t.meterProvider)
	tracer := newFiberMiddlewareTracer(t.tracerProvider)

	mDuration, err := meter.Float64Histogram(
		"http.server.duration",
		metric.WithDescription("HTTP request response times"),
		metric.WithUnit("ms"),
	)
	if err != nil {
		log.Fatal().Err(err).Msg("failed to create request duration histogram")
	}
	mActiveRequests, err := meter.Int64UpDownCounter(
		"http.server.active_requests",
		metric.WithDescription("Number of in-flight HTTP requests"),
		metric.WithUnit("1"),
	)
	if err != nil {
		log.Fatal().Err(err).Msg("failed to create active requests counter")
	}
	mRequestSize, err := meter.Int64Histogram(
		"http.server.request.size",
		metric.WithUnit("By"),
		metric.WithDescription("HTTP request sizes"),
	)
	if err != nil {
		log.Fatal().Err(err).Msg("failed to create request size histogram")
	}
	mResponseSize, err := meter.Int64Histogram(
		"http.server.response.size",
		metric.WithUnit("By"),
		metric.WithDescription("HTTP response sizes"),
	)
	if err != nil {
		log.Fatal().Err(err).Msg("failed to create response size histogram")
	}

	return func(c *fiber.Ctx) error {
		// Skip ignored routes (/ping for example)
		if slices.Contains(cfg.IgnoredRoutes, c.Path()) {
			return c.Next()
		}

		start := time.Now()
		// NOTE(review): these request attributes (url.full, user agent, client
		// address, ...) are also used as metric attributes, which may create
		// high-cardinality metric series — confirm this is intended.
		requestAttrs := httpServerTraceAttributesFromRequest(c)
		activeOpt := metric.WithAttributes(requestAttrs...)
		mActiveRequests.Add(c.Context(), 1, activeOpt)
		// Deferred right after the increment so the in-flight count is
		// balanced on every exit path, including the early error return below
		// and panics raised by downstream handlers.
		defer mActiveRequests.Add(c.Context(), -1, activeOpt)

		var stdRequest http.Request
		if err := fasthttpadaptor.ConvertRequest(c.Context(), &stdRequest, true); err != nil {
			return err
		}
		ctx := t.propagator.Extract(c.UserContext(), propagation.HeaderCarrier(stdRequest.Header))

		// Clone the current hub so the request data below is set on a
		// per-request scope.
		hub := sentry.CurrentHub().Clone()
		if client := hub.Client(); client != nil {
			client.SetSDKIdentifier(fiberMwClientID)
		}
		scope := hub.Scope()
		scope.SetRequest(&stdRequest)
		scope.SetRequestBody(utils.CopyBytes(c.Body()))
		ctx = sentry.SetHubOnContext(ctx, hub)

		description := fmt.Sprintf("%s %s", c.Method(), c.Path())
		span := t.StartSpan(
			ctx,
			"http.server",
			description,
			WithOtelOptions(trace.WithSpanKind(trace.SpanKindServer)),
			WithOtelTracer(tracer),
			t.ContinueFromRequest(&stdRequest),
		)
		// Deferred calls run in reverse order: response headers are injected,
		// then a panic (if any) is reported, and only then is the span ended,
		// so the panic is recorded while the span is still open.
		defer span.End()
		defer func() {
			// TODO: Report panics properly
			r := recover()
			if r == nil {
				return
			}
			var timeout *time.Duration
			if cfg.WaitForDelivery {
				timeout = &cfg.Timeout
			}
			span.Recover(ctx, fmt.Errorf("%v", r), timeout)
			if cfg.Repanic {
				panic(r)
			}
		}()
		defer func() {
			carrier := propagation.HeaderCarrier{}
			t.propagator.Inject(ctx, carrier)
			for _, k := range carrier.Keys() {
				c.Set(k, carrier.Get(k))
			}
		}()

		span.AddAttributes(requestAttrs...)
		for _, k := range cfg.TraceRequestHeaders {
			if h := c.Get(k); h != "" {
				span.AddAttributes(attribute.String("http.request.header."+k, h))
			}
		}

		ctx = span.Context()
		c.SetUserContext(ctx)

		err := c.Next()
		if err != nil {
			if fiberShouldReportError(err) {
				span.CaptureError(err)
			}
			// Run the error handler here so the status code recorded below is
			// the one written to the response.
			err = c.App().Config().ErrorHandler(c, err)
		}

		status := c.Response().StatusCode()
		route := c.Route().Path
		responseAttrs := []attribute.KeyValue{
			semconv.HTTPResponseStatusCode(status),
			semconv.HTTPRouteKey.String(route),
		}
		requestSize := int64(len(c.Request().Body()))
		responseSize := int64(len(c.Response().Body()))

		metricAttrs := make([]attribute.KeyValue, 0, len(requestAttrs)+len(responseAttrs))
		metricAttrs = append(metricAttrs, requestAttrs...)
		metricAttrs = append(metricAttrs, responseAttrs...)
		recordOpt := metric.WithAttributes(metricAttrs...)

		// Fractional milliseconds: Duration.Milliseconds() truncates, which
		// would record every sub-millisecond request as 0.
		elapsedMs := float64(time.Since(start)) / float64(time.Millisecond)
		mDuration.Record(ctx, elapsedMs, recordOpt)
		mRequestSize.Record(ctx, requestSize, recordOpt)
		mResponseSize.Record(ctx, responseSize, recordOpt)

		span.
			AddAttributes(responseAttrs...).
			AddAttributes(attribute.Int64("http.request.headers.content-length", requestSize)).
			SetName(route).
			SetStatus(httpStatusToSpanStatus(status, true), "")
		for _, k := range cfg.TraceResponseHeaders {
			if h := c.GetRespHeader(k); h != "" {
				span.AddAttributes(attribute.String("http.response.header."+k, h))
			}
		}
		return err
	}
}

// fiberShouldReportError reports whether an error returned by a downstream
// handler should be captured on the request span. A *fiber.Error (also when
// wrapped) with a status code below 500 is treated as an expected client
// error and skipped; any other error, including a nil *fiber.Error, is
// reported.
func fiberShouldReportError(err error) bool {
	var fiberErr *fiber.Error
	if errors.As(err, &fiberErr) && fiberErr != nil {
		return fiberErr.Code >= http.StatusInternalServerError
	}
	return true
}
// httpStatusToSpanStatus maps an HTTP status code to a SpanStatus holding an
// OpenTelemetry status code and the Sentry status from
// sentry.HTTPtoSpanStatus.
//
// Codes below 400 map to codes.Ok. For 4xx codes, the HTTP semantic
// conventions require the status to be left Unset on SpanKind.SERVER spans
// and set to Error on SpanKind.CLIENT spans; isServer selects which. 500 and
// above always map to codes.Error.
//
// NOTE(review): the OpenTelemetry tracing API recommends instrumentation
// leave successful spans Unset rather than Ok — confirm Ok is intended here.
func httpStatusToSpanStatus(code int, isServer bool) SpanStatus {
	sentryStatus := sentry.HTTPtoSpanStatus(code)
	switch {
	case code < http.StatusBadRequest:
		return SpanStatus{codes.Ok, sentryStatus}
	case code < http.StatusInternalServerError && isServer:
		return SpanStatus{codes.Unset, sentryStatus}
	default:
		return SpanStatus{codes.Error, sentryStatus}
	}
}
// httpFlavorAttribute reports the request's HTTP version as the
// network.protocol.version attribute: "1.1" when fasthttp marks the request as
// HTTP/1.1, otherwise "1.0".
//
// The semantic conventions keep the protocol name ("http", lower case) and its
// version in separate attributes. The name may be omitted when it is "http",
// so only the version is emitted. The old value "HTTP/1.1" in
// network.protocol.name did not match either attribute's definition.
func httpFlavorAttribute(c *fiber.Ctx) attribute.KeyValue {
	if c.Request().Header.IsHTTP11() {
		return semconv.NetworkProtocolVersion("1.1")
	}
	return semconv.NetworkProtocolVersion("1.0")
}
// httpServerTraceAttributesFromRequest builds the request-derived attributes
// that FiberMiddleware attaches to the server span and to its metrics:
// protocol version, method, the request's Content-Length header, scheme,
// path, absolute URL, host, user agent, transport (always TCP) and, when
// fiber reports one, the client IP.
//
// Strings read from fiber are copied with utils.CopyString, presumably because
// they may alias request buffers that are reused after the handler returns.
func httpServerTraceAttributesFromRequest(c *fiber.Ctx) []attribute.KeyValue {
	attrs := []attribute.KeyValue{
		httpFlavorAttribute(c),
		semconv.HTTPRequestMethodKey.String(utils.CopyString(c.Method())),
		// The value comes from the request's Content-Length header, so it is
		// keyed as a request header.
		attribute.Int("http.request.header.content-length", c.Request().Header.ContentLength()),
		semconv.URLScheme(utils.CopyString(c.Protocol())),
		// url.path carries only the path; the raw request URI would also
		// include the query string.
		semconv.URLPath(utils.CopyString(c.Path())),
		// url.full must be absolute. Per fiber docs, OriginalURL() is only the
		// request-target (path and query), so prefix it with BaseURL()
		// (scheme://host). The concatenation yields a freshly allocated string.
		semconv.URLFull(c.BaseURL() + c.OriginalURL()),
		semconv.ServerAddress(utils.CopyString(c.Hostname())),
		// Converting []byte to string already copies the bytes.
		semconv.UserAgentOriginalKey.String(string(c.Request().Header.UserAgent())),
		semconv.NetworkTransportTCP,
	}
	if clientIP := c.IP(); clientIP != "" {
		attrs = append(attrs, semconv.ClientAddressKey.String(utils.CopyString(clientIP)))
	}
	return attrs
}