package unitel import ( "context" "database/sql" "database/sql/driver" "fmt" "strings" pgx "github.com/jackc/pgx/v5/stdlib" "github.com/qustavo/sqlhooks/v2" "github.com/rs/zerolog/log" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" ) const ( pgxClientID = libBase + "#TracedPGx" TracedPGxDriverName = "pgx-traced" ) type tracedPgxHooks struct { t *Telemetry printQueries bool tracer trace.Tracer } func (h *tracedPgxHooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) { cleanedQuery := strings.ReplaceAll(query, "\n", " ") cleanedQuery = strings.ReplaceAll(cleanedQuery, "\t", " ") cleanedQuery = strings.ReplaceAll(cleanedQuery, " ", " ") cleanedQuery = strings.TrimSpace(cleanedQuery) s := h.t.StartSpan(ctx, "db.sql.query", query) s.AddAttributes( attribute.String("db.system", "postgres"), attribute.String("db.statement", cleanedQuery), attribute.StringSlice("db.params", formatArgs(args)), ) if h.printQueries { l := log.Trace() for i, arg := range args { l = l.Interface(fmt.Sprintf("arg%d", i), arg) } l.Msg(cleanedQuery) } return s.Context(), nil } func (h *tracedPgxHooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) { if s := SpanFromContext(ctx); s != nil { s.End() } return ctx, nil } func (t *Telemetry) RegisterTracedPGx(printQueries bool) { sql.Register(TracedPGxDriverName, t.TracedPGx(printQueries)) log.Debug().Msgf("Registered %s driver", TracedPGxDriverName) } func (t *Telemetry) TracedPGx(printQueries bool) driver.Driver { tracer := t.tracerProvider.Tracer(pgxClientID, trace.WithInstrumentationVersion(libVersion)) return sqlhooks.Wrap(&pgx.Driver{}, &tracedPgxHooks{ printQueries: printQueries, t: t, tracer: tracer, }) } func formatArgs(args []interface{}) []string { formattedArgs := make([]string, len(args)) for i, arg := range args { switch v := arg.(type) { case int: formattedArgs[i] = fmt.Sprint(v) case int64: formattedArgs[i] = fmt.Sprint(v) case float64: formattedArgs[i] = fmt.Sprint(v) case string: formattedArgs[i] = v case []byte: formattedArgs[i] = string(v) case bool: formattedArgs[i] = fmt.Sprint(v) case fmt.Stringer: formattedArgs[i] = v.String() default: formattedArgs[i] = fmt.Sprintf("%+v", v) } } return formattedArgs }