// Package unitelsql wraps a database/sql driver with sqlhooks so that every
// query is logged, traced through unitel and timed into an OpenTelemetry
// histogram.
package unitelsql

import (
	"context"
	"database/sql/driver"
	"fmt"
	"strings"
	"time"

	"git.devminer.xyz/devminer/unitel"
	"git.devminer.xyz/devminer/unitel/unitelutils"
	"github.com/go-logr/logr"
	"github.com/qustavo/sqlhooks/v2"
	"github.com/rs/zerolog/log"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/metric"
	semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
	"go.opentelemetry.io/otel/trace"
)

// TODO: Port to github.com/loghole/dbhook for proper error handling
// NOTE(review): only Before/After are implemented here; presumably failed
// queries never reach After, so their duration is not recorded - confirm
// against the sqlhooks documentation.

// sqlClientID is the instrumentation scope name used for both the meter and
// the tracer created by NewTracedSQL.
const sqlClientID = unitelutils.Base + "#TracedSQL@" + unitelutils.Version

// dbCtxKeyT is a private key type so the context value cannot collide with
// keys set by other packages.
type dbCtxKeyT struct{}

var dbCtxKey = dbCtxKeyT{}

// dbCtxVal carries per-query state from Before to After via the context.
type dbCtxVal struct {
	// start is the instant Before finished its bookkeeping.
	start time.Time
	// attrs are the attributes attached to the duration measurement.
	attrs []attribute.KeyValue
}

// tracedSQLHooks implements the sqlhooks Before/After callbacks.
type tracedSQLHooks struct {
	telemetry *unitel.Telemetry
	logger    logr.Logger
	dbType    string

	mDuration metric.Float64Histogram
	tracer    trace.Tracer
}

// TracedSQLOpt customises the hooks built by NewTracedSQL.
type TracedSQLOpt func(t *tracedSQLHooks)

// WithLogger sets the logger that receives one Info entry per query.
func WithLogger(l logr.Logger) TracedSQLOpt {
	return func(t *tracedSQLHooks) {
		t.logger = l
	}
}

// NewTracedSQL wraps driver so that every query executed through it is
// logged, traced and timed.
//
// dbType is reported as the db.system attribute. opts are applied in order
// after the defaults; without WithLogger, query logging is discarded.
//
// The process is terminated through log.Fatal if the duration histogram
// cannot be created, because the signature has no way to return an error.
func NewTracedSQL(t *unitel.Telemetry, driver driver.Driver, dbType string, opts ...TracedSQLOpt) driver.Driver {
	meter := t.MeterProvider.Meter(sqlClientID, metric.WithInstrumentationVersion(unitelutils.Version))

	// The bucket boundaries below are expressed in seconds, so the unit must
	// be "s" and After must record seconds for the histogram to be meaningful.
	mDuration, err := meter.Float64Histogram(
		"db.client.operation.duration",
		metric.WithDescription("Database query response times"),
		metric.WithUnit("s"),
		metric.WithExplicitBucketBoundaries(0.001, 0.005, 0.01, 0.05, 0.1, 0.5, 1, 5, 10),
	)
	if err != nil {
		log.Fatal().Err(err).Msg("failed to create request duration histogram")
	}

	traced := &tracedSQLHooks{
		dbType:    dbType,
		telemetry: t,
		// Explicit no-op default so Before never logs through an
		// uninitialised logger when WithLogger is not supplied.
		logger:    logr.Discard(),
		mDuration: mDuration,
		tracer:    t.TracerProvider.Tracer(sqlClientID, trace.WithInstrumentationVersion(unitelutils.Version)),
	}

	for _, o := range opts {
		o(traced)
	}

	return sqlhooks.Wrap(driver, traced)
}

// Before runs ahead of each query. It starts a span carrying the query text
// and its parameters, logs the query, records a "started" breadcrumb and
// stores the start time plus the metric attributes in the returned context
// for After to pick up. It always returns a nil error.
func (h *tracedSQLHooks) Before(ctx context.Context, query string, args ...any) (context.Context, error) {
	// Collapse every run of whitespace (newlines, tabs, repeated spaces) into
	// a single space and trim both ends in one pass.
	cleanedQuery := strings.Join(strings.Fields(query), " ")

	// The operation is the first token of the normalised query, so leading
	// whitespace or a newline right after the verb cannot corrupt it.
	operation, _, _ := strings.Cut(cleanedQuery, " ")

	formattedArgs := formatArgs(args)

	// Only attributes whose values do not vary per call go onto the metric;
	// parameter values are unbounded and belong on the span alone.
	metricAttrs := []attribute.KeyValue{
		semconv.DBSystemKey.String(h.dbType),
		semconv.DBOperationName(operation),
		attribute.String("db.query.text", cleanedQuery),
	}

	spanAttrs := make([]attribute.KeyValue, 0, len(metricAttrs)+2+len(formattedArgs))
	spanAttrs = append(spanAttrs, metricAttrs...)
	spanAttrs = append(spanAttrs,
		/* Sentry */
		attribute.String("db.statement", cleanedQuery),
		attribute.StringSlice("db.params", formattedArgs),
	)
	/* OpenTelemetry */
	for i, arg := range formattedArgs {
		spanAttrs = append(spanAttrs, attribute.String(fmt.Sprintf("db.query.parameter.%d", i), arg))
	}

	s := h.telemetry.StartSpan(ctx, "db.sql.query", cleanedQuery, unitel.WithOtelTracer(h.tracer))
	s.AddAttributes(spanAttrs...)

	// Length 0 with capacity 2*len(args): appending to a slice created with a
	// non-zero length would leave nil key/value pairs in front of the real ones.
	kvs := make([]any, 0, 2*len(args))
	for i, arg := range args {
		kvs = append(kvs, fmt.Sprintf("arg%d", i), arg)
	}
	h.logger.Info(cleanedQuery, kvs...)

	s.CaptureBreadcrumb(unitel.SeverityDebug, unitel.BreadcrumbTypeQuery, "started", query, map[string]any{"args": args}).
		End()

	return context.WithValue(s.Context(), dbCtxKey, dbCtxVal{
		start: time.Now(),
		attrs: metricAttrs,
	}), nil
}

// After runs once a query has completed. It records the elapsed time since
// Before, in seconds, on the duration histogram and adds a "finished"
// breadcrumb to the span found in ctx, if any. It always returns ctx
// unchanged and a nil error.
func (h *tracedSQLHooks) After(ctx context.Context, query string, args ...any) (context.Context, error) {
	if val, ok := ctx.Value(dbCtxKey).(dbCtxVal); ok {
		// Seconds() keeps sub-millisecond resolution and matches the
		// histogram's unit and bucket boundaries.
		h.mDuration.Record(ctx, time.Since(val.start).Seconds(), metric.WithAttributes(val.attrs...))
	}

	if s := unitel.SpanFromContext(ctx); s != nil {
		s.CaptureBreadcrumb(unitel.SeverityDebug, unitel.BreadcrumbTypeQuery, "finished", query, map[string]any{"args": args}).
			End()
	}

	return ctx, nil
}

// formatArgs renders every query argument with formatArg, preserving order.
// The result always has the same length as args.
func formatArgs(args []any) []string {
	formattedArgs := make([]string, len(args))
	for i, arg := range args {
		formattedArgs[i] = formatArg(arg)
	}
	return formattedArgs
}

// formatArg converts a single query argument to its string form: strings are
// returned as-is, byte slices are converted directly, the listed primitives
// use their default formatting, fmt.Stringer values use String, and anything
// else is rendered with %+v.
func formatArg(arg any) string {
	switch v := arg.(type) {
	case string:
		return v
	case []byte:
		return string(v)
	case int, int64, float64, bool:
		return fmt.Sprint(v)
	case fmt.Stringer:
		return v.String()
	default:
		return fmt.Sprintf("%+v", v)
	}
}