changed bun to gorm in auth plugin

This commit is contained in:
anthdm 2024-06-23 10:53:56 +02:00
parent 6028c5d60d
commit 6055fc6426
12 changed files with 109 additions and 97 deletions

View file

@ -5,21 +5,25 @@ import (
"os" "os"
"github.com/anthdm/superkit/db" "github.com/anthdm/superkit/db"
"github.com/anthdm/superkit/kit"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"github.com/uptrace/bun" "gorm.io/driver/sqlite"
"github.com/uptrace/bun/dialect/sqlitedialect" "gorm.io/gorm"
"github.com/uptrace/bun/extra/bundebug"
) )
// I could not came up with a better naming for this. // By default this is a pre-configured Gorm DB instance.
// Ideally, app should export a global variable called "DB" // Change this type based on the database package of your likings.
// but this will cause imports cycles for plugins. var dbInstance *gorm.DB
var Query *bun.DB
// Get returns the instantiated DB instance.
func Get() *gorm.DB {
return dbInstance
}
func init() { func init() {
// Create a default *sql.DB exposed by the superkit/db package
// based on the given configuration.
config := db.Config{ config := db.Config{
Driver: os.Getenv("DB_DRIVER"), Driver: os.Getenv("DB_DRIVER"),
Name: os.Getenv("DB_NAME"), Name: os.Getenv("DB_NAME"),
@ -27,12 +31,30 @@ func init() {
User: os.Getenv("DB_USER"), User: os.Getenv("DB_USER"),
Host: os.Getenv("DB_HOST"), Host: os.Getenv("DB_HOST"),
} }
db, err := db.New(config) dbinst, err := db.NewSQL(config)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
Query = bun.NewDB(db, sqlitedialect.New()) // Based on the driver create the corresponding DB instance.
if kit.IsDevelopment() { // By default, the SuperKit boilerplate comes with a pre-configured
Query.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true))) // ORM called Gorm. https://gorm.io.
//
// You can change this to any other DB interaction tool
// of your liking. EG:
// - uptrace bun -> https://bun.uptrace.dev
// - SQLC -> https://github.com/sqlc-dev/sqlc
// - gojet -> https://github.com/go-jet/jet
switch config.Driver {
case db.DriverSqlite3:
dbInstance, err = gorm.Open(sqlite.New(sqlite.Config{
Conn: dbinst,
}))
case db.DriverMysql:
// ...
default:
log.Fatal("invalid driver:", config.Driver)
}
if err != nil {
log.Fatal(err)
} }
} }

View file

@ -5,9 +5,10 @@ create table if not exists users(
password_hash text not null, password_hash text not null,
first_name text not null, first_name text not null,
last_name text not null, last_name text not null,
email_verified_at timestamp with time zone, email_verified_at datetime,
created_at timestamp with time zone not null, created_at datetime not null,
updated_at timestamp with time zone not null updated_at datetime not null,
deleted_at datetime
); );
-- +goose Down -- +goose Down

View file

@ -5,9 +5,10 @@ create table if not exists sessions(
user_id integer not null references users, user_id integer not null references users,
ip_address text, ip_address text,
user_agent text, user_agent text,
expires_at timestamp with time zone not null, expires_at datetime not null,
last_login_at timestamp with time zone, created_at datetime not null,
created_at timestamp with time zone not null updated_at datetime not null,
deleted_at datetime
); );
-- +goose Down -- +goose Down

View file

@ -2,7 +2,7 @@ package types
// AuthUser represents an user that might be authenticated. // AuthUser represents an user that might be authenticated.
type AuthUser struct { type AuthUser struct {
ID int ID uint
Email string Email string
LoggedIn bool LoggedIn bool
} }

View file

@ -3,7 +3,7 @@ module AABBCCDD
go 1.22.4 go 1.22.4
// uncomment for local development on the superkit core. // uncomment for local development on the superkit core.
// replace github.com/anthdm/superkit => ../ replace github.com/anthdm/superkit => ../
require ( require (
github.com/a-h/templ v0.2.707 github.com/a-h/templ v0.2.707
@ -24,10 +24,13 @@ require (
github.com/gorilla/securecookie v1.1.2 // indirect github.com/gorilla/securecookie v1.1.2 // indirect
github.com/gorilla/sessions v1.3.0 // indirect github.com/gorilla/sessions v1.3.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
golang.org/x/sys v0.21.0 // indirect golang.org/x/sys v0.21.0 // indirect
gorm.io/driver/sqlite v1.5.6 // indirect
gorm.io/gorm v1.25.10 // indirect
) )

View file

@ -18,6 +18,8 @@ github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFz
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ= github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
@ -45,3 +47,7 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE=
gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4=
gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s=
gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=

View file

@ -13,6 +13,7 @@ import (
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
) )
const ( const (
@ -40,12 +41,9 @@ func HandleLoginCreate(kit *kit.Kit) error {
} }
var user User var user User
err := db.Query.NewSelect(). err := db.Get().Find(&user, "email = ?", values.Email).Error
Model(&user).
Where("user.email = ?", values.Email).
Scan(kit.Request.Context())
if err != nil { if err != nil {
if err == sql.ErrNoRows { if err == gorm.ErrRecordNotFound {
errors.Add("credentials", "invalid credentials") errors.Add("credentials", "invalid credentials")
return kit.Render(LoginForm(values, errors)) return kit.Render(LoginForm(values, errors))
} }
@ -59,7 +57,7 @@ func HandleLoginCreate(kit *kit.Kit) error {
skipVerify := kit.Getenv("SUPERKIT_AUTH_SKIP_VERIFY", "false") skipVerify := kit.Getenv("SUPERKIT_AUTH_SKIP_VERIFY", "false")
if skipVerify != "true" { if skipVerify != "true" {
if user.EmailVerifiedAt.Equal(time.Time{}) { if !user.EmailVerifiedAt.Valid {
errors.Add("verified", "please verify your email") errors.Add("verified", "please verify your email")
return kit.Render(LoginForm(values, errors)) return kit.Render(LoginForm(values, errors))
} }
@ -73,22 +71,15 @@ func HandleLoginCreate(kit *kit.Kit) error {
session := Session{ session := Session{
UserID: user.ID, UserID: user.ID,
Token: uuid.New().String(), Token: uuid.New().String(),
CreatedAt: time.Now(),
LastLoginAt: time.Now(),
ExpiresAt: time.Now().Add(time.Hour * time.Duration(sessionExpiry)), ExpiresAt: time.Now().Add(time.Hour * time.Duration(sessionExpiry)),
} }
_, err = db.Query.NewInsert(). if err = db.Get().Create(&session).Error; err != nil {
Model(&session).
Exec(kit.Request.Context())
if err != nil {
return err return err
} }
// TODO change this with kit.Getenv
sess := kit.GetSession(userSessionName) sess := kit.GetSession(userSessionName)
sess.Values["sessionToken"] = session.Token sess.Values["sessionToken"] = session.Token
sess.Save(kit.Request, kit.Response) sess.Save(kit.Request, kit.Response)
redirectURL := kit.Getenv("SUPERKIT_AUTH_REDIRECT_AFTER_LOGIN", "/profile") redirectURL := kit.Getenv("SUPERKIT_AUTH_REDIRECT_AFTER_LOGIN", "/profile")
return kit.Redirect(http.StatusSeeOther, redirectURL) return kit.Redirect(http.StatusSeeOther, redirectURL)
@ -100,10 +91,7 @@ func HandleLoginDelete(kit *kit.Kit) error {
sess.Values = map[any]any{} sess.Values = map[any]any{}
sess.Save(kit.Request, kit.Response) sess.Save(kit.Request, kit.Response)
}() }()
_, err := db.Query.NewDelete(). err := db.Get().Delete(&Session{}, "token = ?", sess.Values["sessionToken"]).Error
Model((*Session)(nil)).
Where("token = ?", sess.Values["sessionToken"]).
Exec(kit.Request.Context())
if err != nil { if err != nil {
return err return err
} }
@ -121,7 +109,7 @@ func HandleEmailVerify(kit *kit.Kit) error {
return []byte(os.Getenv("SUPERKIT_SECRET")), nil return []byte(os.Getenv("SUPERKIT_SECRET")), nil
}, jwt.WithLeeway(5*time.Second)) }, jwt.WithLeeway(5*time.Second))
if err != nil { if err != nil {
return err return kit.Render(EmailVerificationError("invalid verification token"))
} }
if !token.Valid { if !token.Valid {
return kit.Render(EmailVerificationError("invalid verification token")) return kit.Render(EmailVerificationError("invalid verification token"))
@ -141,23 +129,18 @@ func HandleEmailVerify(kit *kit.Kit) error {
} }
var user User var user User
err = db.Query.NewSelect(). err = db.Get().First(&user, userID).Error
Model(&user).
Where("id = ?", userID).
Scan(kit.Request.Context())
if err != nil { if err != nil {
return err return err
} }
if user.EmailVerifiedAt.After(time.Time{}) { if user.EmailVerifiedAt.Time.After(time.Time{}) {
return kit.Render(EmailVerificationError("Email already verified")) return kit.Render(EmailVerificationError("Email already verified"))
} }
user.EmailVerifiedAt = time.Now() now := sql.NullTime{Time: time.Now(), Valid: true}
_, err = db.Query.NewUpdate(). user.EmailVerifiedAt = now
Model(&user). err = db.Get().Save(&user).Error
WherePK().
Exec(kit.Request.Context())
if err != nil { if err != nil {
return err return err
} }
@ -174,16 +157,14 @@ func AuthenticateUser(kit *kit.Kit) (kit.Auth, error) {
} }
var session Session var session Session
err := db.Query.NewSelect(). err := db.Get().
Model(&session). Preload("User").
Relation("User"). Find(&session, "token = ? AND expires_at > ?", token, time.Now()).Error
Where("session.token = ? AND session.expires_at > ?", token, time.Now()). if err != nil || session.ID == 0 {
Scan(kit.Request.Context())
if err != nil {
return auth, nil return auth, nil
} }
return Auth{
return Auth{
LoggedIn: true, LoggedIn: true,
UserID: session.User.ID, UserID: session.User.ID,
Email: session.User.Email, Email: session.User.Email,

View file

@ -14,7 +14,7 @@ var profileSchema = v.Schema{
} }
type ProfileFormValues struct { type ProfileFormValues struct {
ID int `form:"id"` ID uint `form:"id"`
FirstName string `form:"firstName"` FirstName string `form:"firstName"`
LastName string `form:"lastName"` LastName string `form:"lastName"`
Email string Email string
@ -25,11 +25,7 @@ func HandleProfileShow(kit *kit.Kit) error {
auth := kit.Auth().(Auth) auth := kit.Auth().(Auth)
var user User var user User
err := db.Query.NewSelect(). if err := db.Get().First(&user, auth.UserID).Error; err != nil {
Model(&user).
Where("id = ?", auth.UserID).
Scan(kit.Request.Context())
if err != nil {
return err return err
} }
@ -54,12 +50,12 @@ func HandleProfileUpdate(kit *kit.Kit) error {
if auth.UserID != values.ID { if auth.UserID != values.ID {
return fmt.Errorf("unauthorized request for profile %d", values.ID) return fmt.Errorf("unauthorized request for profile %d", values.ID)
} }
_, err := db.Query.NewUpdate(). err := db.Get().Model(&User{}).
Model((*User)(nil)).
Set("first_name = ?", values.FirstName).
Set("last_name = ?", values.LastName).
Where("id = ?", auth.UserID). Where("id = ?", auth.UserID).
Exec(kit.Request.Context()) Updates(&User{
FirstName: values.FirstName,
LastName: values.LastName,
}).Error
if err != nil { if err != nil {
return err return err
} }

View file

@ -63,19 +63,15 @@ func HandleResendVerificationCode(kit *kit.Kit) error {
} }
var user User var user User
err = db.Query.NewSelect(). if err = db.Get().First(&user, id).Error; err != nil {
Model(&user).
Where("id = ?", id).
Scan(kit.Request.Context())
if err != nil {
return kit.Text(http.StatusOK, "An unexpected error occured") return kit.Text(http.StatusOK, "An unexpected error occured")
} }
if user.EmailVerifiedAt.After(time.Time{}) { if user.EmailVerifiedAt.Time.After(time.Time{}) {
return kit.Text(http.StatusOK, "Email already verified!") return kit.Text(http.StatusOK, "Email already verified!")
} }
token, err := createVerificationToken(id) token, err := createVerificationToken(uint(id))
if err != nil { if err != nil {
return kit.Text(http.StatusOK, "An unexpected error occured") return kit.Text(http.StatusOK, "An unexpected error occured")
} }
@ -90,7 +86,7 @@ func HandleResendVerificationCode(kit *kit.Kit) error {
return kit.Text(http.StatusOK, msg) return kit.Text(http.StatusOK, msg)
} }
func createVerificationToken(userID int) (string, error) { func createVerificationToken(userID uint) (string, error) {
expiryStr := kit.Getenv("SUPERKIT_AUTH_EMAIL_VERIFICATION_EXPIRY_IN_HOURS", "1") expiryStr := kit.Getenv("SUPERKIT_AUTH_EMAIL_VERIFICATION_EXPIRY_IN_HOURS", "1")
expiry, err := strconv.Atoi(expiryStr) expiry, err := strconv.Atoi(expiryStr)
if err != nil { if err != nil {

View file

@ -2,10 +2,11 @@ package auth
import ( import (
"AABBCCDD/app/db" "AABBCCDD/app/db"
"context" "database/sql"
"time" "time"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
) )
// Event name constants // Event name constants
@ -22,7 +23,7 @@ type UserWithVerificationToken struct {
} }
type Auth struct { type Auth struct {
UserID int UserID uint
Email string Email string
LoggedIn bool LoggedIn bool
} }
@ -32,12 +33,13 @@ func (auth Auth) Check() bool {
} }
type User struct { type User struct {
ID int `bun:"id,pk,autoincrement"` gorm.Model
Email string Email string
FirstName string FirstName string
LastName string LastName string
PasswordHash string PasswordHash string
EmailVerifiedAt time.Time EmailVerifiedAt sql.NullTime
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
} }
@ -52,22 +54,19 @@ func createUserFromFormValues(values SignupFormValues) (User, error) {
FirstName: values.FirstName, FirstName: values.FirstName,
LastName: values.LastName, LastName: values.LastName,
PasswordHash: string(hash), PasswordHash: string(hash),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
} }
_, err = db.Query.NewInsert().Model(&user).Exec(context.Background()) result := db.Get().Create(&user)
return user, err return user, result.Error
} }
type Session struct { type Session struct {
ID int `bun:"id,pk,autoincrement"` gorm.Model
UserID int
UserID uint
Token string Token string
IPAddress string IPAddress string
UserAgent string UserAgent string
ExpiresAt time.Time ExpiresAt time.Time
LastLoginAt time.Time
CreatedAt time.Time CreatedAt time.Time
User User
User User `bun:"rel:belongs-to,join:user_id=id"`
} }

View file

@ -7,6 +7,7 @@ import (
const ( const (
DriverSqlite3 = "sqlite3" DriverSqlite3 = "sqlite3"
DriverMysql = "mysql"
) )
type Config struct { type Config struct {
@ -17,7 +18,7 @@ type Config struct {
Password string Password string
} }
func New(cfg Config) (*sql.DB, error) { func NewSQL(cfg Config) (*sql.DB, error) {
switch cfg.Driver { switch cfg.Driver {
case DriverSqlite3: case DriverSqlite3:
name := cfg.Name name := cfg.Name

View file

@ -153,12 +153,18 @@ func parseRequest(r *http.Request, v any) error {
} }
case reflect.String: case reflect.String:
fieldVal.SetString(formValue) fieldVal.SetString(formValue)
case reflect.Int: case reflect.Int, reflect.Int32, reflect.Int64:
intVal, err := strconv.Atoi(formValue) intVal, err := strconv.Atoi(formValue)
if err != nil { if err != nil {
return fmt.Errorf("failed to parse int: %v", err) return fmt.Errorf("failed to parse int: %v", err)
} }
fieldVal.SetInt(int64(intVal)) fieldVal.SetInt(int64(intVal))
case reflect.Uint, reflect.Uint32, reflect.Uint64:
intVal, err := strconv.Atoi(formValue)
if err != nil {
return fmt.Errorf("failed to parse int: %v", err)
}
fieldVal.SetUint(uint64(intVal))
case reflect.Float64: case reflect.Float64:
floatVal, err := strconv.ParseFloat(formValue, 64) floatVal, err := strconv.ParseFloat(formValue, 64)
if err != nil { if err != nil {