diff --git a/pgx_tracing.go b/sql_tracing.go similarity index 76% rename from pgx_tracing.go rename to sql_tracing.go index ecaa890..ba54209 100644 --- a/pgx_tracing.go +++ b/sql_tracing.go @@ -2,13 +2,11 @@ package unitel import ( "context" - "database/sql" "database/sql/driver" "fmt" "strings" "time" - pgx "github.com/jackc/pgx/v5/stdlib" "github.com/qustavo/sqlhooks/v2" "github.com/rs/zerolog/log" "go.opentelemetry.io/otel/attribute" @@ -18,13 +16,19 @@ import ( ) const ( - pgxClientID = libBase + "#TracedPGx" - TracedPGxDriverName = "pgx-traced" + sqlClientID = libBase + "#TracedSQL" ) -type dbCtxStart struct{} +var dbCtxKey = struct{}{} + +type dbCtxVal struct { + start time.Time + attrs []attribute.KeyValue +} type tracedPgxHooks struct { + dbType string + t *Telemetry printQueries bool tracer trace.Tracer @@ -41,8 +45,9 @@ func (h *tracedPgxHooks) Before(ctx context.Context, query string, args ...inter operation := strings.Split(query, " ")[0] attrs := []attribute.KeyValue{ - semconv.DBSystemPostgreSQL, + semconv.DBSystemKey.String(h.dbType), semconv.DBOperationName(operation), + /* Sentry */ attribute.String("db.statement", cleanedQuery), attribute.StringSlice("db.params", formatArgs(args)), @@ -67,7 +72,10 @@ func (h *tracedPgxHooks) Before(ctx context.Context, query string, args ...inter l.Msg(cleanedQuery) } - return context.WithValue(s.Context(), dbCtxStart{}, time.Now()), nil + return context.WithValue(s.Context(), dbCtxKey, dbCtxVal{ + start: time.Now(), + attrs: attrs, + }), nil } func (h *tracedPgxHooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) { @@ -75,22 +83,18 @@ func (h *tracedPgxHooks) After(ctx context.Context, query string, args ...interf s.End() } - if start, ok := ctx.Value(dbCtxStart{}).(time.Time); ok { - h.mDuration.Record(ctx, float64(time.Since(start).Milliseconds())) + if val, ok := ctx.Value(dbCtxKey).(dbCtxVal); ok { + attrs := val.attrs + + h.mDuration.Record(ctx, float64(time.Since(val.start).Milliseconds()), metric.WithAttributes(attrs...)) } return ctx, nil } -func (t *Telemetry) RegisterTracedPGx(printQueries bool) { - sql.Register(TracedPGxDriverName, t.TracedPGx(printQueries)) - - log.Debug().Msgf("Registered %s driver", TracedPGxDriverName) -} - -func (t *Telemetry) TracedPGx(printQueries bool) driver.Driver { - tracer := t.tracerProvider.Tracer(pgxClientID, trace.WithInstrumentationVersion(libVersion)) - meter := t.meterProvider.Meter(pgxClientID, metric.WithInstrumentationVersion(libVersion)) +func (t *Telemetry) TraceSQL(driver driver.Driver, dbType string, printQueries bool) driver.Driver { + tracer := t.tracerProvider.Tracer(sqlClientID, trace.WithInstrumentationVersion(libVersion)) + meter := t.meterProvider.Meter(sqlClientID, metric.WithInstrumentationVersion(libVersion)) mDuration, err := meter.Float64Histogram( "db.client.operation.duration", @@ -102,11 +106,14 @@ func (t *Telemetry) TracedPGx(printQueries bool) driver.Driver { log.Fatal().Err(err).Msg("failed to create request duration histogram") } - return sqlhooks.Wrap(&pgx.Driver{}, &tracedPgxHooks{ + return sqlhooks.Wrap(driver, &tracedPgxHooks{ + dbType: dbType, + printQueries: printQueries, t: t, tracer: tracer, - mDuration: mDuration, + + mDuration: mDuration, }) }