unitel/sql_tracing.go

149 lines
3.4 KiB
Go

package unitel
import (
"context"
"database/sql/driver"
"fmt"
"strings"
"time"
"github.com/qustavo/sqlhooks/v2"
"github.com/rs/zerolog/log"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
"go.opentelemetry.io/otel/trace"
)
// sqlClientID names the instrumentation scope under which TraceSQL obtains
// its tracer and meter.
const sqlClientID = libBase + "#TracedSQL"
// dbCtxKeyType is the unexported type of dbCtxKey. A dedicated type guarantees
// the key cannot collide with context keys set by other code: a plain
// struct{}{} value compares equal to every other struct{}{} key.
type dbCtxKeyType struct{}

// dbCtxKey is the context key under which Before stores the per-query dbCtxVal.
var dbCtxKey = dbCtxKeyType{}
// dbCtxVal is the per-query state that Before stores in the context under
// dbCtxKey and that After reads back to record the query duration.
type dbCtxVal struct {
	// start is the time captured at the end of Before; After measures the
	// query duration from it.
	start time.Time
	// attrs are the attributes passed to the duration histogram in After.
	attrs []attribute.KeyValue
}
// tracedSQLHooks holds the configuration for the sqlhooks hooks (Before/After)
// installed by TraceSQL.
type tracedSQLHooks struct {
	// dbType is reported as the db.system attribute.
	dbType string
	// t is used to start the per-query spans.
	t *Telemetry
	// printQueries enables trace-level logging of each query and its args.
	printQueries bool
	// tracer is handed to t.StartSpan via WithOtelTracer.
	tracer trace.Tracer
	// mDuration is the histogram query durations are recorded on.
	mDuration metric.Float64Histogram
}
// Before is invoked by sqlhooks ahead of every query. It normalizes the query
// text, starts a span describing the query and its arguments, optionally logs
// the query at trace level, and stores the start time together with the metric
// attributes in the returned context so After can finish the span and record
// the duration.
func (h *tracedSQLHooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
	// Collapse every whitespace run (newlines, tabs, repeated spaces) into a
	// single space so multi-line queries read as one line.
	cleanedQuery := strings.Join(strings.Fields(query), " ")

	// The operation is the first word of the normalized query. Taking it from
	// the raw query would yield "" or "\n\tSELECT" for indented SQL.
	operation, _, _ := strings.Cut(cleanedQuery, " ")

	formattedArgs := formatArgs(args)

	// Attributes shared by the span and the duration metric. Argument values
	// are deliberately kept off the metric: they are unbounded, and every
	// distinct value would otherwise create a new metric time series.
	metricAttrs := []attribute.KeyValue{
		semconv.DBSystemKey.String(h.dbType),
		semconv.DBOperationName(operation),
		/* OpenTelemetry */
		attribute.String("db.query.text", cleanedQuery),
	}

	spanAttrs := make([]attribute.KeyValue, 0, len(metricAttrs)+2+len(formattedArgs))
	spanAttrs = append(spanAttrs, metricAttrs...)
	spanAttrs = append(spanAttrs,
		/* Sentry */
		attribute.String("db.statement", cleanedQuery),
		attribute.StringSlice("db.params", formattedArgs),
	)
	/* OpenTelemetry */
	for i, arg := range formattedArgs {
		spanAttrs = append(spanAttrs, attribute.String(fmt.Sprintf("db.query.parameter.%d", i), arg))
	}

	s := h.t.StartSpan(ctx, "db.sql.query", cleanedQuery, WithOtelTracer(h.tracer))
	s.AddAttributes(spanAttrs...)

	if h.printQueries {
		l := log.Trace()
		for i, arg := range args {
			l = l.Interface(fmt.Sprintf("arg%d", i), arg)
		}
		l.Msg(cleanedQuery)
	}

	return context.WithValue(s.Context(), dbCtxKey, dbCtxVal{
		start: time.Now(),
		attrs: metricAttrs,
	}), nil
}
// After is invoked by sqlhooks once a query has completed without error. It
// ends the span started in Before and records the query duration.
func (h *tracedSQLHooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
	h.finish(ctx, true)
	return ctx, nil
}

// OnError implements sqlhooks' OnErrorer hook. sqlhooks calls it instead of
// After when the query fails. Without it, the span from Before would never be
// ended, and failed queries would be missing from the duration metric.
//
// err is returned unchanged so the caller of database/sql sees the original
// error.
func (h *tracedSQLHooks) OnError(ctx context.Context, err error, query string, args ...interface{}) error {
	// driver.ErrSkip is not a real failure: database/sql, which compares it by
	// identity, retries through a fallback path that goes through Before again.
	// Close this span, but leave it untagged and out of the metric.
	if err == driver.ErrSkip {
		h.finish(ctx, false)
		return err
	}
	h.finish(ctx, true, attribute.String("error.type", fmt.Sprintf("%T", err)))
	return err
}

// finish ends the span stored in ctx (if any) after adding extra to it. When
// record is true, it also records the elapsed time since Before on the
// duration histogram, using the attributes Before stored plus extra.
func (h *tracedSQLHooks) finish(ctx context.Context, record bool, extra ...attribute.KeyValue) {
	val, hasVal := ctx.Value(dbCtxKey).(dbCtxVal)
	var elapsed time.Duration
	if hasVal {
		// Measure before ending the span so span-end work is not counted.
		elapsed = time.Since(val.start)
	}

	if s := SpanFromContext(ctx); s != nil {
		if len(extra) > 0 {
			s.AddAttributes(extra...)
		}
		s.End()
	}

	if !record || !hasVal {
		return
	}
	attrs := val.attrs
	if len(extra) > 0 {
		// Build a fresh slice so extra never writes into val.attrs' backing array.
		attrs = make([]attribute.KeyValue, 0, len(val.attrs)+len(extra))
		attrs = append(attrs, val.attrs...)
		attrs = append(attrs, extra...)
	}
	// Convert with full precision: Duration.Milliseconds() truncates to an
	// integer, which would record every sub-millisecond query as 0.
	h.mDuration.Record(ctx, float64(elapsed)/float64(time.Millisecond), metric.WithAttributes(attrs...))
}
// TraceSQL wraps d with sqlhooks so that every query executed through it is
// traced and timed.
//
//   - Each query gets a span from a tracer obtained from t's tracer provider.
//   - Its duration is recorded, in milliseconds, on the
//     "db.client.operation.duration" histogram from t's meter provider.
//   - dbType is reported as the db.system attribute.
//   - When printQueries is true, each query and its arguments are also logged
//     at trace level.
//
// If the histogram cannot be created, TraceSQL calls zerolog's log.Fatal,
// which by default exits the process.
func (t *Telemetry) TraceSQL(d driver.Driver, dbType string, printQueries bool) driver.Driver {
	tracer := t.tracerProvider.Tracer(sqlClientID, trace.WithInstrumentationVersion(libVersion))
	meter := t.meterProvider.Meter(sqlClientID, metric.WithInstrumentationVersion(libVersion))

	mDuration, err := meter.Float64Histogram(
		"db.client.operation.duration",
		metric.WithDescription("Database query response times"),
		metric.WithUnit("ms"),
		// The semantic-convention boundaries for this metric (1ms .. 10s),
		// expressed in milliseconds to match the unit above. The previous
		// values (0.001 .. 10) were the seconds boundaries, which capped the
		// histogram at 10ms.
		metric.WithExplicitBucketBoundaries(1, 5, 10, 50, 100, 500, 1000, 5000, 10000),
	)
	if err != nil {
		log.Fatal().Err(err).Msg("failed to create db query duration histogram")
	}

	return sqlhooks.Wrap(d, &tracedSQLHooks{
		dbType:       dbType,
		printQueries: printQueries,
		t:            t,
		tracer:       tracer,
		mDuration:    mDuration,
	})
}
// formatArgs renders each query argument with formatArg, preserving order.
// The result always has len(args) elements; it is an empty, non-nil slice
// when args is empty.
func formatArgs(args []interface{}) []string {
	formattedArgs := make([]string, len(args))
	for i, arg := range args {
		formattedArgs[i] = formatArg(arg)
	}
	return formattedArgs
}

// formatArg renders a single query argument as a string for span attributes:
//   - strings are returned as-is and []byte values are converted to string;
//   - ints, int64s, float64s and bools use fmt's default formatting;
//   - fmt.Stringer values use their String method;
//   - anything else falls back to the %+v verb.
func formatArg(arg interface{}) string {
	switch v := arg.(type) {
	case int, int64, float64, bool:
		return fmt.Sprint(v)
	case string:
		return v
	case []byte:
		return string(v)
	case fmt.Stringer:
		// Drivers may pass arbitrary values through. A typed nil pointer whose
		// String method dereferences its receiver would panic here and take
		// down the traced query, so String is called behind a recover.
		return stringerValue(v)
	default:
		return fmt.Sprintf("%+v", v)
	}
}

// stringerValue returns s.String(), or an fmt-style placeholder describing the
// panic if String panics.
func stringerValue(s fmt.Stringer) (str string) {
	defer func() {
		if r := recover(); r != nil {
			str = fmt.Sprintf("%%!v(PANIC=String method: %v)", r)
		}
	}()
	return s.String()
}