unitel/pgx_tracing.go

104 lines
2.4 KiB
Go
Raw Normal View History

2024-07-23 17:46:15 +02:00
package unitel
import (
"context"
"database/sql"
"database/sql/driver"
2024-07-23 17:46:15 +02:00
"fmt"
"strings"
pgx "github.com/jackc/pgx/v5/stdlib"
"github.com/qustavo/sqlhooks/v2"
"github.com/rs/zerolog/log"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
const (
	// TracedPGxDriverName is the database/sql driver name under which
	// RegisterTracedPGx registers the traced pgx driver; use it as the
	// driverName argument of sql.Open.
	TracedPGxDriverName = "pgx-traced"

	// pgxClientID is the instrumentation scope name given to the tracer
	// created in TracedPGx.
	pgxClientID = libBase + "#TracedPGx"
)
// tracedPgxHooks is the sqlhooks hooks value installed by TracedPGx. Its
// Before method opens a span per statement and After closes it.
type tracedPgxHooks struct {
	// t is the owning Telemetry; spans are started through t.StartSpan.
	t *Telemetry
	// tracer is the scoped tracer built in TracedPGx.
	// NOTE(review): none of the hooks in this file read it — spans are created
	// via t.StartSpan instead. Confirm whether it is still needed.
	tracer trace.Tracer
	// printQueries enables trace-level logging of each statement and its args.
	printQueries bool
}
// Before is the sqlhooks pre-execution hook. It starts a "db.sql.query" span
// for the statement and annotates it with:
//   - the database system,
//   - a whitespace-normalized copy of the statement,
//   - the formatted arguments.
//
// When printQueries is set, it also emits the normalized statement as a
// trace-level log message, with each argument attached as an "argN" field.
//
// The returned context is s.Context() of the new span; After retrieves the
// span from it through SpanFromContext to end it.
func (h *tracedPgxHooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
	// Collapse every run of whitespace (newlines, tabs, repeated spaces, CRs)
	// into one space and trim both ends, so multi-line SQL reads as a single
	// line in traces and logs. The previous chain of ReplaceAll calls did not
	// collapse repeated spaces at all.
	cleanedQuery := strings.Join(strings.Fields(query), " ")

	s := h.t.StartSpan(ctx, "db.sql.query", query)
	s.AddAttributes(
		attribute.String("db.system", "postgres"),
		attribute.String("db.statement", cleanedQuery),
		attribute.StringSlice("db.params", formatArgs(args)),
	)

	if h.printQueries {
		l := log.Trace()
		for i, arg := range args {
			l = l.Interface(fmt.Sprintf("arg%d", i), arg)
		}
		l.Msg(cleanedQuery)
	}

	return s.Context(), nil
}
// After is the sqlhooks post-execution hook for statements that succeeded.
// It ends the span that Before placed in ctx, if one is present.
func (h *tracedPgxHooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
	if s := SpanFromContext(ctx); s != nil {
		s.End()
	}
	return ctx, nil
}

// OnError is the sqlhooks error hook (it satisfies sqlhooks.OnErrorer).
// sqlhooks calls it instead of After when the wrapped driver call fails.
// Without it, the span started in Before would never be ended for failing
// statements.
//
// It ends the span found in ctx, if any, and returns err unchanged so the
// caller still sees the original driver error.
// NOTE(review): consider also recording err on the span once the Span API
// for that is confirmed.
func (h *tracedPgxHooks) OnError(ctx context.Context, err error, query string, args ...interface{}) error {
	if s := SpanFromContext(ctx); s != nil {
		s.End()
	}
	return err
}
// RegisterTracedPGx registers the driver returned by TracedPGx with
// database/sql under TracedPGxDriverName, then logs the registration at
// debug level.
//
// Because sql.Register panics when a name is registered twice, this method
// must be called at most once per process.
func (t *Telemetry) RegisterTracedPGx(printQueries bool) {
	drv := t.TracedPGx(printQueries)
	sql.Register(TracedPGxDriverName, drv)

	log.Debug().Msgf("Registered %s driver", TracedPGxDriverName)
}
func (t *Telemetry) TracedPGx(printQueries bool) driver.Driver {
2024-07-23 17:46:15 +02:00
tracer := t.tracerProvider.Tracer(pgxClientID, trace.WithInstrumentationVersion(libVersion))
return sqlhooks.Wrap(&pgx.Driver{}, &tracedPgxHooks{
2024-07-23 17:46:15 +02:00
printQueries: printQueries,
t: t,
tracer: tracer,
})
2024-07-23 17:46:15 +02:00
}
// formatArgs renders SQL query arguments as strings for the "db.params" span
// attribute. The result has exactly one entry per argument, in order; an
// empty or nil args slice yields an empty slice.
//
// Formatting rules:
//   - strings are passed through;
//   - []byte is converted to a string as-is;
//   - int, int64, float64 and bool use fmt's default formatting;
//   - fmt.Stringer values use their String method (panic-safe, see
//     stringerString);
//   - anything else falls back to "%+v".
func formatArgs(args []interface{}) []string {
	formattedArgs := make([]string, len(args))
	for i, arg := range args {
		switch v := arg.(type) {
		case string:
			formattedArgs[i] = v
		case []byte:
			formattedArgs[i] = string(v)
		case int, int64, float64, bool:
			formattedArgs[i] = fmt.Sprint(v)
		case fmt.Stringer:
			formattedArgs[i] = stringerString(v)
		default:
			formattedArgs[i] = fmt.Sprintf("%+v", v)
		}
	}
	return formattedArgs
}

// stringerString returns s.String(). If the String method panics, it falls
// back to fmt formatting, which recovers the panic itself. A common source is
// a nil pointer whose String dereferences its receiver, e.g. a nil pointer arg
// used for a NULL value. Such a panic must not crash the query from inside a
// tracing hook. fmt renders a nil-pointer receiver as "<nil>" and any other
// panic as "%!v(PANIC=String method: ...)".
func stringerString(s fmt.Stringer) (str string) {
	defer func() {
		if recover() != nil {
			str = fmt.Sprintf("%+v", s)
		}
	}()
	return s.String()
}