changed bun to gorm in auth plugin

This commit is contained in:
anthdm 2024-06-23 10:53:56 +02:00
parent 6028c5d60d
commit 6055fc6426
12 changed files with 109 additions and 97 deletions

View file

@ -5,21 +5,25 @@ import (
"os"
"github.com/anthdm/superkit/db"
"github.com/anthdm/superkit/kit"
_ "github.com/mattn/go-sqlite3"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/sqlitedialect"
"github.com/uptrace/bun/extra/bundebug"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// I could not came up with a better naming for this.
// Ideally, app should export a global variable called "DB"
// but this will cause imports cycles for plugins.
var Query *bun.DB
// By default this is a pre-configured Gorm DB instance.
// Change this type based on the database package of your likings.
var dbInstance *gorm.DB
// Get returns the instantiated DB instance.
func Get() *gorm.DB {
return dbInstance
}
func init() {
// Create a default *sql.DB exposed by the superkit/db package
// based on the given configuration.
config := db.Config{
Driver: os.Getenv("DB_DRIVER"),
Name: os.Getenv("DB_NAME"),
@ -27,12 +31,30 @@ func init() {
User: os.Getenv("DB_USER"),
Host: os.Getenv("DB_HOST"),
}
db, err := db.New(config)
dbinst, err := db.NewSQL(config)
if err != nil {
log.Fatal(err)
}
Query = bun.NewDB(db, sqlitedialect.New())
if kit.IsDevelopment() {
Query.AddQueryHook(bundebug.NewQueryHook(bundebug.WithVerbose(true)))
// Based on the driver create the corresponding DB instance.
// By default, the SuperKit boilerplate comes with a pre-configured
// ORM called Gorm. https://gorm.io.
//
// You can change this to any other DB interaction tool
// of your liking. EG:
// - uptrace bun -> https://bun.uptrace.dev
// - SQLC -> https://github.com/sqlc-dev/sqlc
// - gojet -> https://github.com/go-jet/jet
switch config.Driver {
case db.DriverSqlite3:
dbInstance, err = gorm.Open(sqlite.New(sqlite.Config{
Conn: dbinst,
}))
case db.DriverMysql:
// ...
default:
log.Fatal("invalid driver:", config.Driver)
}
if err != nil {
log.Fatal(err)
}
}

View file

@ -5,9 +5,10 @@ create table if not exists users(
password_hash text not null,
first_name text not null,
last_name text not null,
email_verified_at timestamp with time zone,
created_at timestamp with time zone not null,
updated_at timestamp with time zone not null
email_verified_at datetime,
created_at datetime not null,
updated_at datetime not null,
deleted_at datetime
);
-- +goose Down

View file

@ -5,9 +5,10 @@ create table if not exists sessions(
user_id integer not null references users,
ip_address text,
user_agent text,
expires_at timestamp with time zone not null,
last_login_at timestamp with time zone,
created_at timestamp with time zone not null
expires_at datetime not null,
created_at datetime not null,
updated_at datetime not null,
deleted_at datetime
);
-- +goose Down

View file

@ -2,7 +2,7 @@ package types
// AuthUser represents an user that might be authenticated.
type AuthUser struct {
ID int
ID uint
Email string
LoggedIn bool
}

View file

@ -3,7 +3,7 @@ module AABBCCDD
go 1.22.4
// uncomment for local development on the superkit core.
// replace github.com/anthdm/superkit => ../
replace github.com/anthdm/superkit => ../
require (
github.com/a-h/templ v0.2.707
@ -24,10 +24,13 @@ require (
github.com/gorilla/securecookie v1.1.2 // indirect
github.com/gorilla/sessions v1.3.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/tmthrgd/go-hex v0.0.0-20190904060850-447a3041c3bc // indirect
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
golang.org/x/sys v0.21.0 // indirect
gorm.io/driver/sqlite v1.5.6 // indirect
gorm.io/gorm v1.25.10 // indirect
)

View file

@ -18,6 +18,8 @@ github.com/gorilla/sessions v1.3.0 h1:XYlkq7KcpOB2ZhHBPv5WpjMIxrQosiZanfoy1HLZFz
github.com/gorilla/sessions v1.3.0/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
@ -45,3 +47,7 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws=
golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE=
gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4=
gorm.io/gorm v1.25.10 h1:dQpO+33KalOA+aFYGlK+EfxcI5MbO7EP2yYygwh9h+s=
gorm.io/gorm v1.25.10/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=

View file

@ -13,6 +13,7 @@ import (
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
const (
@ -40,12 +41,9 @@ func HandleLoginCreate(kit *kit.Kit) error {
}
var user User
err := db.Query.NewSelect().
Model(&user).
Where("user.email = ?", values.Email).
Scan(kit.Request.Context())
err := db.Get().Find(&user, "email = ?", values.Email).Error
if err != nil {
if err == sql.ErrNoRows {
if err == gorm.ErrRecordNotFound {
errors.Add("credentials", "invalid credentials")
return kit.Render(LoginForm(values, errors))
}
@ -59,7 +57,7 @@ func HandleLoginCreate(kit *kit.Kit) error {
skipVerify := kit.Getenv("SUPERKIT_AUTH_SKIP_VERIFY", "false")
if skipVerify != "true" {
if user.EmailVerifiedAt.Equal(time.Time{}) {
if !user.EmailVerifiedAt.Valid {
errors.Add("verified", "please verify your email")
return kit.Render(LoginForm(values, errors))
}
@ -73,22 +71,15 @@ func HandleLoginCreate(kit *kit.Kit) error {
session := Session{
UserID: user.ID,
Token: uuid.New().String(),
CreatedAt: time.Now(),
LastLoginAt: time.Now(),
ExpiresAt: time.Now().Add(time.Hour * time.Duration(sessionExpiry)),
}
_, err = db.Query.NewInsert().
Model(&session).
Exec(kit.Request.Context())
if err != nil {
if err = db.Get().Create(&session).Error; err != nil {
return err
}
// TODO change this with kit.Getenv
sess := kit.GetSession(userSessionName)
sess.Values["sessionToken"] = session.Token
sess.Save(kit.Request, kit.Response)
redirectURL := kit.Getenv("SUPERKIT_AUTH_REDIRECT_AFTER_LOGIN", "/profile")
return kit.Redirect(http.StatusSeeOther, redirectURL)
@ -100,10 +91,7 @@ func HandleLoginDelete(kit *kit.Kit) error {
sess.Values = map[any]any{}
sess.Save(kit.Request, kit.Response)
}()
_, err := db.Query.NewDelete().
Model((*Session)(nil)).
Where("token = ?", sess.Values["sessionToken"]).
Exec(kit.Request.Context())
err := db.Get().Delete(&Session{}, "token = ?", sess.Values["sessionToken"]).Error
if err != nil {
return err
}
@ -121,7 +109,7 @@ func HandleEmailVerify(kit *kit.Kit) error {
return []byte(os.Getenv("SUPERKIT_SECRET")), nil
}, jwt.WithLeeway(5*time.Second))
if err != nil {
return err
return kit.Render(EmailVerificationError("invalid verification token"))
}
if !token.Valid {
return kit.Render(EmailVerificationError("invalid verification token"))
@ -141,23 +129,18 @@ func HandleEmailVerify(kit *kit.Kit) error {
}
var user User
err = db.Query.NewSelect().
Model(&user).
Where("id = ?", userID).
Scan(kit.Request.Context())
err = db.Get().First(&user, userID).Error
if err != nil {
return err
}
if user.EmailVerifiedAt.After(time.Time{}) {
if user.EmailVerifiedAt.Time.After(time.Time{}) {
return kit.Render(EmailVerificationError("Email already verified"))
}
user.EmailVerifiedAt = time.Now()
_, err = db.Query.NewUpdate().
Model(&user).
WherePK().
Exec(kit.Request.Context())
now := sql.NullTime{Time: time.Now(), Valid: true}
user.EmailVerifiedAt = now
err = db.Get().Save(&user).Error
if err != nil {
return err
}
@ -174,16 +157,14 @@ func AuthenticateUser(kit *kit.Kit) (kit.Auth, error) {
}
var session Session
err := db.Query.NewSelect().
Model(&session).
Relation("User").
Where("session.token = ? AND session.expires_at > ?", token, time.Now()).
Scan(kit.Request.Context())
if err != nil {
err := db.Get().
Preload("User").
Find(&session, "token = ? AND expires_at > ?", token, time.Now()).Error
if err != nil || session.ID == 0 {
return auth, nil
}
return Auth{
return Auth{
LoggedIn: true,
UserID: session.User.ID,
Email: session.User.Email,

View file

@ -14,7 +14,7 @@ var profileSchema = v.Schema{
}
type ProfileFormValues struct {
ID int `form:"id"`
ID uint `form:"id"`
FirstName string `form:"firstName"`
LastName string `form:"lastName"`
Email string
@ -25,11 +25,7 @@ func HandleProfileShow(kit *kit.Kit) error {
auth := kit.Auth().(Auth)
var user User
err := db.Query.NewSelect().
Model(&user).
Where("id = ?", auth.UserID).
Scan(kit.Request.Context())
if err != nil {
if err := db.Get().First(&user, auth.UserID).Error; err != nil {
return err
}
@ -54,12 +50,12 @@ func HandleProfileUpdate(kit *kit.Kit) error {
if auth.UserID != values.ID {
return fmt.Errorf("unauthorized request for profile %d", values.ID)
}
_, err := db.Query.NewUpdate().
Model((*User)(nil)).
Set("first_name = ?", values.FirstName).
Set("last_name = ?", values.LastName).
err := db.Get().Model(&User{}).
Where("id = ?", auth.UserID).
Exec(kit.Request.Context())
Updates(&User{
FirstName: values.FirstName,
LastName: values.LastName,
}).Error
if err != nil {
return err
}

View file

@ -63,19 +63,15 @@ func HandleResendVerificationCode(kit *kit.Kit) error {
}
var user User
err = db.Query.NewSelect().
Model(&user).
Where("id = ?", id).
Scan(kit.Request.Context())
if err != nil {
if err = db.Get().First(&user, id).Error; err != nil {
return kit.Text(http.StatusOK, "An unexpected error occured")
}
if user.EmailVerifiedAt.After(time.Time{}) {
if user.EmailVerifiedAt.Time.After(time.Time{}) {
return kit.Text(http.StatusOK, "Email already verified!")
}
token, err := createVerificationToken(id)
token, err := createVerificationToken(uint(id))
if err != nil {
return kit.Text(http.StatusOK, "An unexpected error occured")
}
@ -90,7 +86,7 @@ func HandleResendVerificationCode(kit *kit.Kit) error {
return kit.Text(http.StatusOK, msg)
}
func createVerificationToken(userID int) (string, error) {
func createVerificationToken(userID uint) (string, error) {
expiryStr := kit.Getenv("SUPERKIT_AUTH_EMAIL_VERIFICATION_EXPIRY_IN_HOURS", "1")
expiry, err := strconv.Atoi(expiryStr)
if err != nil {

View file

@ -2,10 +2,11 @@ package auth
import (
"AABBCCDD/app/db"
"context"
"database/sql"
"time"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// Event name constants
@ -22,7 +23,7 @@ type UserWithVerificationToken struct {
}
type Auth struct {
UserID int
UserID uint
Email string
LoggedIn bool
}
@ -32,12 +33,13 @@ func (auth Auth) Check() bool {
}
type User struct {
ID int `bun:"id,pk,autoincrement"`
gorm.Model
Email string
FirstName string
LastName string
PasswordHash string
EmailVerifiedAt time.Time
EmailVerifiedAt sql.NullTime
CreatedAt time.Time
UpdatedAt time.Time
}
@ -52,22 +54,19 @@ func createUserFromFormValues(values SignupFormValues) (User, error) {
FirstName: values.FirstName,
LastName: values.LastName,
PasswordHash: string(hash),
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
_, err = db.Query.NewInsert().Model(&user).Exec(context.Background())
return user, err
result := db.Get().Create(&user)
return user, result.Error
}
type Session struct {
ID int `bun:"id,pk,autoincrement"`
UserID int
gorm.Model
UserID uint
Token string
IPAddress string
UserAgent string
ExpiresAt time.Time
LastLoginAt time.Time
CreatedAt time.Time
User User `bun:"rel:belongs-to,join:user_id=id"`
User User
}

View file

@ -7,6 +7,7 @@ import (
const (
DriverSqlite3 = "sqlite3"
DriverMysql = "mysql"
)
type Config struct {
@ -17,7 +18,7 @@ type Config struct {
Password string
}
func New(cfg Config) (*sql.DB, error) {
func NewSQL(cfg Config) (*sql.DB, error) {
switch cfg.Driver {
case DriverSqlite3:
name := cfg.Name

View file

@ -153,12 +153,18 @@ func parseRequest(r *http.Request, v any) error {
}
case reflect.String:
fieldVal.SetString(formValue)
case reflect.Int:
case reflect.Int, reflect.Int32, reflect.Int64:
intVal, err := strconv.Atoi(formValue)
if err != nil {
return fmt.Errorf("failed to parse int: %v", err)
}
fieldVal.SetInt(int64(intVal))
case reflect.Uint, reflect.Uint32, reflect.Uint64:
intVal, err := strconv.Atoi(formValue)
if err != nil {
return fmt.Errorf("failed to parse int: %v", err)
}
fieldVal.SetUint(uint64(intVal))
case reflect.Float64:
floatVal, err := strconv.ParseFloat(formValue, 64)
if err != nil {