From 3bb8dde6708fe64b050f811cc5e08a1812d7fd79 Mon Sep 17 00:00:00 2001 From: Andrea Fazzi Date: Wed, 15 Jan 2020 11:27:00 +0100 Subject: [PATCH] Encapsulate orm behaviour in the Database struct --- handlers/handlers.go | 36 +++++++++-------- handlers/login.go | 46 +++++++++++----------- main.go | 22 ++++------- orm/answer.go | 40 +++++++++---------- orm/category.go | 40 +++++++++---------- orm/contest.go | 42 ++++++++++---------- orm/orm.go | 60 +++++++++++++++++------------ orm/participant.go | 92 ++++++++++++++++++++++---------------------- orm/question.go | 40 +++++++++---------- orm/region.go | 44 ++++++++++----------- orm/response.go | 46 +++++++++++----------- orm/school.go | 54 +++++++++++++------------- orm/user.go | 38 +++++++++--------- orm/useraction.go | 12 +++--- 14 files changed, 308 insertions(+), 304 deletions(-) diff --git a/handlers/handlers.go b/handlers/handlers.go index 78110f5f..18887863 100644 --- a/handlers/handlers.go +++ b/handlers/handlers.go @@ -35,12 +35,13 @@ type PathPattern struct { type Handlers struct { Config *config.ConfigT - Models []interface{} + Database *orm.Database + Models []interface{} - Login func(store *sessions.CookieStore, signingKey []byte) http.Handler + Login func(db *orm.Database, store *sessions.CookieStore, signingKey []byte) http.Handler Logout func(store *sessions.CookieStore) http.Handler Home func() http.Handler - GetToken func(signingKey []byte) http.Handler + GetToken func(db *orm.Database, signingKey []byte) http.Handler Static func() http.Handler Recover func(next http.Handler) http.Handler @@ -164,10 +165,11 @@ func (h *Handlers) generateModelHandlers(r *mux.Router, model interface{}) { } -func NewHandlers(config *config.ConfigT, models []interface{}) *Handlers { +func NewHandlers(config *config.ConfigT, db *orm.Database, models []interface{}) *Handlers { handlers := new(Handlers) handlers.Config = config + handlers.Database = db handlers.CookieStore = sessions.NewCookieStore([]byte(config.Keys.CookieStoreKey)) @@ -197,12 +199,12 @@ func NewHandlers(config *config.ConfigT, models []interface{}) *Handlers { // Authentication - r.Handle("/login", handlers.Login(handlers.CookieStore, []byte(config.Keys.JWTSigningKey))) + r.Handle("/login", handlers.Login(handlers.Database, handlers.CookieStore, []byte(config.Keys.JWTSigningKey))) r.Handle("/logout", handlers.Logout(handlers.CookieStore)) // School subscription - r.Handle("/subscribe", handlers.Login(handlers.CookieStore, []byte(config.Keys.JWTSigningKey))) + r.Handle("/subscribe", handlers.Login(handlers.Database, handlers.CookieStore, []byte(config.Keys.JWTSigningKey))) // Home @@ -216,7 +218,7 @@ func NewHandlers(config *config.ConfigT, models []interface{}) *Handlers { // Token handling - r.Handle("/get_token", handlers.GetToken([]byte(config.Keys.JWTSigningKey))) + r.Handle("/get_token", handlers.GetToken(handlers.Database, []byte(config.Keys.JWTSigningKey))) // Static file server @@ -293,7 +295,7 @@ func hasPermission(role, path string) bool { func (h *Handlers) get(w http.ResponseWriter, r *http.Request, model string, pattern PathPattern) { format := r.URL.Query().Get("format") - getFn, err := orm.GetFunc(pattern.Path(model)) + getFn, err := h.Database.GetFunc(pattern.Path(model)) if err != nil { log.Println("Error:", err) respondWithError(w, r, err) @@ -306,7 +308,7 @@ func (h *Handlers) get(w http.ResponseWriter, r *http.Request, model string, pat renderer.Render[format](w, r, fmt.Errorf("%s", "Errore di autorizzazione")) } else { - data, err := getFn(mux.Vars(r), w, r) + data, err := getFn(h.Database, mux.Vars(r), w, r) if err != nil { renderer.Render[format](w, r, err) } else { @@ -317,14 +319,14 @@ func (h *Handlers) get(w http.ResponseWriter, r *http.Request, model string, pat } -func post(w http.ResponseWriter, r *http.Request, model string, pattern PathPattern) { +func (h *Handlers) post(w http.ResponseWriter, r *http.Request, model string, pattern PathPattern) { var ( data interface{} err error ) respFormat := renderer.GetContentFormat(r) - postFn, err := orm.GetFunc(pattern.Path(model)) + postFn, err := h.Database.GetFunc(pattern.Path(model)) if err != nil { respondWithError(w, r, err) @@ -335,7 +337,7 @@ func post(w http.ResponseWriter, r *http.Request, model string, pattern PathPatt if !hasPermission(role, pattern.Path(model)) { renderer.Render[respFormat](w, r, fmt.Errorf("%s", "Errore di autorizzazione")) } else { - data, err = postFn(mux.Vars(r), w, r) + data, err = postFn(h.Database, mux.Vars(r), w, r) if err != nil { respondWithError(w, r, err) } else if pattern.RedirectPattern != "" { @@ -353,7 +355,7 @@ func post(w http.ResponseWriter, r *http.Request, model string, pattern PathPatt } -func delete(w http.ResponseWriter, r *http.Request, model string, pattern PathPattern) { +func (h *Handlers) delete(w http.ResponseWriter, r *http.Request, model string, pattern PathPattern) { var data interface{} respFormat := renderer.GetContentFormat(r) @@ -364,11 +366,11 @@ func delete(w http.ResponseWriter, r *http.Request, model string, pattern PathPa if !hasPermission(role, pattern.Path(model)) { renderer.Render[respFormat](w, r, fmt.Errorf("%s", "Errore di autorizzazione")) } else { - postFn, err := orm.GetFunc(pattern.Path(model)) + postFn, err := h.Database.GetFunc(pattern.Path(model)) if err != nil { renderer.Render[r.URL.Query().Get("format")](w, r, err) } - data, err = postFn(mux.Vars(r), w, r) + data, err = postFn(h.Database, mux.Vars(r), w, r) if err != nil { renderer.Render["html"](w, r, err) } else if pattern.RedirectPattern != "" { @@ -402,10 +404,10 @@ func (h *Handlers) modelHandler(model string, pattern PathPattern) http.Handler h.get(w, r, model, pattern) case "POST": - post(w, r, model, pattern) + h.post(w, r, model, pattern) case "DELETE": - delete(w, r, model, pattern) + h.delete(w, r, model, pattern) } } diff --git a/handlers/login.go b/handlers/login.go index f21e7d8e..53280146 100644 --- a/handlers/login.go +++ b/handlers/login.go @@ -59,14 +59,14 @@ func DefaultLogoutHandler(store *sessions.CookieStore) http.Handler { return http.HandlerFunc(fn) } -func DefaultLoginHandler(store *sessions.CookieStore, signingKey []byte) http.Handler { +func DefaultLoginHandler(db *orm.Database, store *sessions.CookieStore, signingKey []byte) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { if r.Method == "GET" { renderer.Render["html"](w, r, nil, r.URL.Query()) } if r.Method == "POST" { r.ParseForm() - token, err := getToken(r.FormValue("username"), r.FormValue("password"), signingKey) + token, err := getToken(db, r.FormValue("username"), r.FormValue("password"), signingKey) if err != nil { http.Redirect(w, r, "/login?tpl_layout=login&tpl_content=login&failed=true", http.StatusSeeOther) } else { @@ -87,9 +87,7 @@ func DefaultLoginHandler(store *sessions.CookieStore, signingKey []byte) http.Ha // FIXME: This is an hack for fast prototyping: users should have // their own table on DB. -func checkCredential(username string, password string) (*UserToken, error) { - var user orm.User - +func checkCredential(db *orm.Database, username string, password string) (*UserToken, error) { // Check if user is the administrator if username == config.Config.Admin.Username && password == config.Config.Admin.Password { @@ -104,31 +102,31 @@ func checkCredential(username string, password string) (*UserToken, error) { var token *UserToken - if err := orm.DB().Where("username = ? AND password = ?", username, password).First(&user).Error; err != nil { + user, err := db.GetUser(username, password) + if err != nil { return nil, errors.New("Authentication failed!") - } else { - switch user.Role { - case "participant": - var participant orm.Participant - if err := orm.DB().First(&participant, &orm.Participant{UserID: user.ID}).Error; err != nil { - return nil, errors.New("Authentication failed!") - } - token = &UserToken{username, false, user.Role, strconv.Itoa(int(participant.ID))} - case "school": - var school orm.School - if err := orm.DB().First(&school, &orm.School{UserID: user.ID}).Error; err != nil { - return nil, errors.New("Authentication failed!") - } - token = &UserToken{username, false, user.Role, strconv.Itoa(int(school.ID))} + } + switch user.Role { + case "participant": + var participant orm.Participant + if err := db.DB().First(&participant, &orm.Participant{UserID: user.ID}).Error; err != nil { + return nil, errors.New("Authentication failed!") } + token = &UserToken{username, false, user.Role, strconv.Itoa(int(participant.ID))} + case "school": + var school orm.School + if err := db.DB().First(&school, &orm.School{UserID: user.ID}).Error; err != nil { + return nil, errors.New("Authentication failed!") + } + token = &UserToken{username, false, user.Role, strconv.Itoa(int(school.ID))} } return token, nil } // FIXME: Refactor the functions above please!!! -func getToken(username string, password string, signingKey []byte) ([]byte, error) { - user, err := checkCredential(username, password) +func getToken(db *orm.Database, username string, password string, signingKey []byte) ([]byte, error) { + user, err := checkCredential(db, username, password) if err != nil { return nil, err } @@ -154,11 +152,11 @@ func getToken(username string, password string, signingKey []byte) ([]byte, erro } // FIXME: Refactor the functions above please!!! -func DefaultGetTokenHandler(signingKey []byte) http.Handler { +func DefaultGetTokenHandler(db *orm.Database, signingKey []byte) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { username, password, _ := r.BasicAuth() - token, err := getToken(username, password, signingKey) + token, err := getToken(db, username, password, signingKey) if err != nil { panic(err) } diff --git a/main.go b/main.go index 33cbc239..8b303d68 100644 --- a/main.go +++ b/main.go @@ -2,14 +2,12 @@ package main import ( "flag" - "fmt" "log" "net/http" "os" "time" "github.com/gorilla/handlers" - "github.com/jinzhu/gorm" "git.andreafazzi.eu/andrea/oef/config" oef_handlers "git.andreafazzi.eu/andrea/oef/handlers" @@ -22,7 +20,7 @@ const ( ) var ( - db *gorm.DB + db *orm.Database err error models = []interface{}{ @@ -54,7 +52,7 @@ func main() { wait := true for wait && count > 0 { - db, err = orm.New(fmt.Sprintf("%s?%s", config.Config.Orm.Connection, config.Config.Orm.Options)) + db, err = orm.NewDatabase(config.Config, models) if err != nil { count-- log.Println(err) @@ -65,26 +63,22 @@ func main() { wait = false } - orm.Use(db) + // REMOVE + // orm.Use(db) if config.Config.Orm.AutoMigrate { log.Print("Automigrating...") - orm.AutoMigrate(models...) + db.AutoMigrate() } log.Println("Eventually write categories on DB...") - orm.CreateCategories() + orm.CreateCategories(db) log.Println("Eventually write regions on DB...") - orm.CreateRegions() - - log.Println("Map models <-> handlers") - if err := orm.MapHandlers(models); err != nil { - panic(err) - } + orm.CreateRegions(db) log.Println("OEF is listening to port 3000...") - if err := http.ListenAndServe(":3000", handlers.LoggingHandler(os.Stdout, oef_handlers.NewHandlers(config.Config, models).Router)); err != nil { + if err := http.ListenAndServe(":3000", handlers.LoggingHandler(os.Stdout, oef_handlers.NewHandlers(config.Config, db, models).Router)); err != nil { panic(err) } diff --git a/orm/answer.go b/orm/answer.go index 833de0e8..83f4512e 100644 --- a/orm/answer.go +++ b/orm/answer.go @@ -26,10 +26,10 @@ func (a *Answer) String() string { return a.Text } -func (a *Answer) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (a *Answer) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { if r.Method == "GET" { answer := new(Answer) - if err := DB().Find(&answer.AllQuestions).Error; err != nil { + if err := db._db.Find(&answer.AllQuestions).Error; err != nil { return nil, err } @@ -40,7 +40,7 @@ func (a *Answer) Create(args map[string]string, w http.ResponseWriter, r *http.R if err != nil { return nil, err } - answer, err = CreateAnswer(answer) + answer, err = CreateAnswer(db, answer) if err != nil { return nil, err } @@ -48,36 +48,36 @@ func (a *Answer) Create(args map[string]string, w http.ResponseWriter, r *http.R } } -func (a *Answer) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (a *Answer) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { var answer Answer id := args["id"] - if err := DB().Preload("Question").First(&answer, id).Error; err != nil { + if err := db._db.Preload("Question").First(&answer, id).Error; err != nil { return nil, err } return &answer, nil } -func (a *Answer) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (a *Answer) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { var answers []*Answer - if err := DB().Preload("Question").Order("created_at").Find(&answers).Error; err != nil { + if err := db._db.Preload("Question").Order("created_at").Find(&answers).Error; err != nil { return nil, err } return answers, nil } -func (a *Answer) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (a *Answer) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { if r.Method == "GET" { - result, err := a.Read(args, w, r) + result, err := a.Read(db, args, w, r) if err != nil { return nil, err } answer := result.(*Answer) - if err := DB().Find(&answer.AllQuestions).Error; err != nil { + if err := db._db.Find(&answer.AllQuestions).Error; err != nil { return nil, err } @@ -86,7 +86,7 @@ func (a *Answer) Update(args map[string]string, w http.ResponseWriter, r *http.R return answer, nil } else { - answer, err := a.Read(args, w, r) + answer, err := a.Read(db, args, w, r) if err != nil { return nil, err } @@ -98,11 +98,11 @@ func (a *Answer) Update(args map[string]string, w http.ResponseWriter, r *http.R if err != nil { return nil, err } - _, err = SaveAnswer(answer) + _, err = SaveAnswer(db, answer) if err != nil { return nil, err } - answer, err = a.Read(args, w, r) + answer, err = a.Read(db, args, w, r) if err != nil { return nil, err } @@ -110,26 +110,26 @@ func (a *Answer) Update(args map[string]string, w http.ResponseWriter, r *http.R } } -func (model *Answer) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { - answer, err := model.Read(args, w, r) +func (model *Answer) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { + answer, err := model.Read(db, args, w, r) if err != nil { return nil, err } - if err := DB().Unscoped().Delete(answer.(*Answer)).Error; err != nil { + if err := db._db.Unscoped().Delete(answer.(*Answer)).Error; err != nil { return nil, err } return answer.(*Answer), nil } -func CreateAnswer(answer *Answer) (*Answer, error) { - if err := DB().Create(answer).Error; err != nil { +func CreateAnswer(db *Database, answer *Answer) (*Answer, error) { + if err := db._db.Create(answer).Error; err != nil { return nil, err } return answer, nil } -func SaveAnswer(answer interface{}) (interface{}, error) { - if err := DB().Omit("Answers").Save(answer).Error; err != nil { +func SaveAnswer(db *Database, answer interface{}) (interface{}, error) { + if err := db._db.Omit("Answers").Save(answer).Error; err != nil { return nil, err } return answer, nil diff --git a/orm/category.go b/orm/category.go index 41e858cd..12001679 100644 --- a/orm/category.go +++ b/orm/category.go @@ -23,10 +23,10 @@ func (model *Category) String() string { return model.Name } -func (model *Category) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *Category) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { if r.Method == "GET" { category := new(Category) - // if err := DB().Find(&category.AllContests).Error; err != nil { + // if err := db._db.Find(&category.AllContests).Error; err != nil { // return nil, err // } return category, nil @@ -36,7 +36,7 @@ func (model *Category) Create(args map[string]string, w http.ResponseWriter, r * if err != nil { return nil, err } - category, err = CreateCategory(category) + category, err = CreateCategory(db, category) if err != nil { return nil, err } @@ -44,36 +44,36 @@ func (model *Category) Create(args map[string]string, w http.ResponseWriter, r * } } -func (model *Category) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *Category) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { var category Category id := args["id"] - if err := DB(). /*.Preload("Something")*/ First(&category, id).Error; err != nil { + if err := db._db. /*.Preload("Something")*/ First(&category, id).Error; err != nil { return nil, err } return &category, nil } -func (model *Category) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *Category) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { var categories []*Category - if err := DB(). /*.Preload("Something")*/ Order("created_at").Find(&categories).Error; err != nil { + if err := db._db. /*.Preload("Something")*/ Order("created_at").Find(&categories).Error; err != nil { return nil, err } return categories, nil } -func (model *Category) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *Category) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { if r.Method == "GET" { - result, err := model.Read(args, w, r) + result, err := model.Read(db, args, w, r) if err != nil { return nil, err } category := result.(*Category) - // if err := DB().Find(&category.AllElements).Error; err != nil { + // if err := db._db.Find(&category.AllElements).Error; err != nil { // return nil, err // } @@ -82,7 +82,7 @@ func (model *Category) Update(args map[string]string, w http.ResponseWriter, r * return category, nil } else { - category, err := model.Read(args, w, r) + category, err := model.Read(db, args, w, r) if err != nil { return nil, err } @@ -90,11 +90,11 @@ func (model *Category) Update(args map[string]string, w http.ResponseWriter, r * if err != nil { return nil, err } - _, err = SaveCategory(category) + _, err = SaveCategory(db, category) if err != nil { return nil, err } - category, err = model.Read(args, w, r) + category, err = model.Read(db, args, w, r) if err != nil { return nil, err } @@ -102,26 +102,26 @@ func (model *Category) Update(args map[string]string, w http.ResponseWriter, r * } } -func (model *Category) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { - category, err := model.Read(args, w, r) +func (model *Category) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { + category, err := model.Read(db, args, w, r) if err != nil { return nil, err } - if err := DB().Unscoped().Delete(category.(*Category)).Error; err != nil { + if err := db._db.Unscoped().Delete(category.(*Category)).Error; err != nil { return nil, err } return category.(*Category), nil } -func CreateCategory(category *Category) (*Category, error) { - if err := DB().Create(category).Error; err != nil { +func CreateCategory(db *Database, category *Category) (*Category, error) { + if err := db._db.Create(category).Error; err != nil { return nil, err } return category, nil } -func SaveCategory(category interface{}) (interface{}, error) { - if err := DB(). /*.Omit("Something")*/ Save(category).Error; err != nil { +func SaveCategory(db *Database, category interface{}) (interface{}, error) { + if err := db._db. /*.Omit("Something")*/ Save(category).Error; err != nil { return nil, err } return category, nil diff --git a/orm/contest.go b/orm/contest.go index 3aff913d..308b1a66 100644 --- a/orm/contest.go +++ b/orm/contest.go @@ -39,7 +39,7 @@ func (c *Contest) String() string { return c.Name } -func (c *Contest) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (c *Contest) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { if r.Method == "GET" { return nil, nil } else { @@ -68,7 +68,7 @@ func (c *Contest) Create(args map[string]string, w http.ResponseWriter, r *http. if err != nil { return nil, err } - contest, err = CreateContest(contest) + contest, err = CreateContest(db, contest) if err != nil { return nil, err } @@ -76,25 +76,25 @@ func (c *Contest) Create(args map[string]string, w http.ResponseWriter, r *http. } } -func (c *Contest) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (c *Contest) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { var contest Contest id := args["id"] - if err := DB().Preload("Participants").Preload("Questions").First(&contest, id).Error; err != nil { + if err := db._db.Preload("Participants").Preload("Questions").First(&contest, id).Error; err != nil { return nil, err } return &contest, nil } -func (c *Contest) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (c *Contest) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { var contests []*Contest claims := r.Context().Value("user").(*jwt.Token).Claims.(jwt.MapClaims) if claims["admin"].(bool) { - if err := DB().Order("created_at").Find(&contests).Error; err != nil { + if err := db._db.Order("created_at").Find(&contests).Error; err != nil { return nil, err } else { return contests, nil @@ -103,16 +103,16 @@ func (c *Contest) ReadAll(args map[string]string, w http.ResponseWriter, r *http participant := &Participant{} - if err := DB().Preload("Contests").Where("username = ?", claims["name"].(string)).First(&participant).Error; err != nil { + if err := db._db.Preload("Contests").Where("username = ?", claims["name"].(string)).First(&participant).Error; err != nil { return nil, err } else { return participant.Contests, nil } } -func (c *Contest) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (c *Contest) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { if r.Method == "GET" { - result, err := c.Read(args, w, r) + result, err := c.Read(db, args, w, r) if err != nil { return nil, err } @@ -120,7 +120,7 @@ func (c *Contest) Update(args map[string]string, w http.ResponseWriter, r *http. return contest, nil } else { - contest, err := c.Read(args, w, r) + contest, err := c.Read(db, args, w, r) if err != nil { return nil, err } @@ -148,11 +148,11 @@ func (c *Contest) Update(args map[string]string, w http.ResponseWriter, r *http. if err != nil { return nil, err } - _, err = SaveContest(contest) + _, err = SaveContest(db, contest) if err != nil { return nil, err } - contest, err = c.Read(args, w, r) + contest, err = c.Read(db, args, w, r) if err != nil { return nil, err } @@ -160,26 +160,26 @@ func (c *Contest) Update(args map[string]string, w http.ResponseWriter, r *http. } } -func (model *Contest) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { - contest, err := model.Read(args, w, r) +func (model *Contest) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { + contest, err := model.Read(db, args, w, r) if err != nil { return nil, err } - if err := DB().Unscoped().Delete(contest.(*Contest)).Error; err != nil { + if err := db._db.Unscoped().Delete(contest.(*Contest)).Error; err != nil { return nil, err } return contest.(*Contest), nil } -func CreateContest(contest *Contest) (*Contest, error) { - if err := DB().Create(contest).Error; err != nil { +func CreateContest(db *Database, contest *Contest) (*Contest, error) { + if err := db._db.Create(contest).Error; err != nil { return nil, err } return contest, nil } -func SaveContest(contest interface{}) (interface{}, error) { - if err := DB().Omit("Contests").Save(contest).Error; err != nil { +func SaveContest(db *Database, contest interface{}) (interface{}, error) { + if err := db._db.Omit("Contests").Save(contest).Error; err != nil { return nil, err } return contest, nil @@ -189,13 +189,13 @@ func (c *Contest) isAlwaysActive() bool { return c.StartTime.IsZero() || c.EndTime.IsZero() || c.Duration == 0 } -func (c *Contest) generateQuestionsOrder() (string, error) { +func (c *Contest) generateQuestionsOrder(db *gorm.DB) (string, error) { var ( order []string questions []*Question ) - if err := DB().Find(&questions, Question{ContestID: c.ID}).Error; err != nil { + if err := db.Find(&questions, Question{ContestID: c.ID}).Error; err != nil { return "", err } diff --git a/orm/orm.go b/orm/orm.go index a1c4e251..ad7e6223 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -1,6 +1,7 @@ package orm import ( + "errors" "fmt" "net/http" "path" @@ -18,51 +19,60 @@ type IDer interface { GetID() uint } -type GetFn func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error) +type GetFn func(*Database, map[string]string, http.ResponseWriter, *http.Request) (interface{}, error) type Database struct { Config *config.ConfigT - _db *gorm.DB - fns map[string]func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error) + models []interface{} + _db *gorm.DB + fns map[string]func(*Database, map[string]string, http.ResponseWriter, *http.Request) (interface{}, error) } -func NewDatabase(config *config.ConfigT) (*Database, error) { +func NewDatabase(config *config.ConfigT, models []interface{}) (*Database, error) { + var err error + db := new(Database) - db.fns = make(db*gorm.DB, map[string]func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error), 0) - db._db, err := gorm.Open("mysql", fmt.Sprintf("%s?%s", config.Config.Orm.Connection, config.Config.Orm.Options)) + db.fns = make(map[string]func(*Database, map[string]string, http.ResponseWriter, *http.Request) (interface{}, error), 0) + db.mapHandlers(models) + + db._db, err = gorm.Open("mysql", fmt.Sprintf("%s?%s", config.Orm.Connection, config.Orm.Options)) if err != nil { return nil, err } - return db + return db, err } -func (db *Database) AutoMigrate(models ...interface{}) { - if err := db._db.AutoMigrate(models...).Error; err != nil { +func (db *Database) AutoMigrate() { + if err := db._db.AutoMigrate(db.models...).Error; err != nil { panic(err) } } -func CreateCategories() { +func (db *Database) GetUser(username, password string) (*User, error) { + var user *User + if err := db._db.Where("username = ? AND password = ?", username, password).First(&user).Error; err != nil { + return nil, errors.New("Authentication failed!") + } + return user, nil +} + +func (db *Database) DB() *gorm.DB { + return db._db +} + +func CreateCategories(db *Database) { for _, name := range categories { var category Category - if err := currDB.FirstOrCreate(&category, Category{Name: name}).Error; err != nil { + if err := db._db.FirstOrCreate(&category, Category{Name: name}).Error; err != nil { panic(err) } } } -func Use(db *gorm.DB) { - currDB = db -} - -func DB() *gorm.DB { - return currDB -} - -func MapHandlers(models []interface{}) error { +func (db *Database) mapHandlers(models []interface{}) error { for _, model := range models { name := inflection.Plural(strings.ToLower(ModelName(model))) for p, action := range map[string]string{ @@ -80,7 +90,7 @@ func MapHandlers(models []interface{}) error { if strings.HasSuffix(p, "/") { joinedPath += "/" } - fns[joinedPath] = method.Interface().(func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error)) + db.fns[joinedPath] = method.Interface().(func(*Database, map[string]string, http.ResponseWriter, *http.Request) (interface{}, error)) } } @@ -88,19 +98,19 @@ func MapHandlers(models []interface{}) error { return nil } -func GetFunc(path string) (GetFn, error) { - fn, ok := fns[path] +func (db *Database) GetFunc(path string) (GetFn, error) { + fn, ok := db.fns[path] if !ok { return nil, fmt.Errorf("Can't map path %s to any model methods.", path) } return fn, nil } -func GetNothing(args map[string]string) (interface{}, error) { +func GetNothing(db *Database, args map[string]string) (interface{}, error) { return nil, nil } -func PostNothing(args map[string]string, w http.ResponseWriter, r *http.Request) (IDer, error) { +func PostNothing(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (IDer, error) { return nil, nil } diff --git a/orm/participant.go b/orm/participant.go index bec4fb83..67a62742 100644 --- a/orm/participant.go +++ b/orm/participant.go @@ -92,9 +92,9 @@ func (model *Participant) String() string { // return nil // } -func (model *Participant) exists() (*User, error) { +func (model *Participant) exists(db *Database) (*User, error) { var user User - if err := DB().First(&user, &User{Username: model.username()}).Error; err != nil && err != gorm.ErrRecordNotFound { + if err := db._db.First(&user, &User{Username: model.username()}).Error; err != nil && err != gorm.ErrRecordNotFound { return nil, err } else if err == gorm.ErrRecordNotFound { return nil, nil @@ -129,7 +129,7 @@ func (model *Participant) AfterSave(tx *gorm.DB) error { return err } - order, err := contest.generateQuestionsOrder() + order, err := contest.generateQuestionsOrder(tx) if err != nil { return err } @@ -148,21 +148,21 @@ func (model *Participant) AfterDelete(tx *gorm.DB) error { return nil } -func (model *Participant) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *Participant) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { if r.Method == "GET" { participant := new(Participant) if isSchool(r) { - if err := DB().Find(&participant.AllCategories).Error; err != nil { + if err := db._db.Find(&participant.AllCategories).Error; err != nil { return nil, err } } else { - if err := DB().Find(&participant.AllCategories).Error; err != nil { + if err := db._db.Find(&participant.AllCategories).Error; err != nil { return nil, err } - if err := DB().Find(&participant.AllContests).Error; err != nil { + if err := db._db.Find(&participant.AllContests).Error; err != nil { return nil, err } - if err := DB().Find(&participant.AllSchools).Error; err != nil { + if err := db._db.Find(&participant.AllSchools).Error; err != nil { return nil, err } } @@ -175,8 +175,8 @@ func (model *Participant) Create(args map[string]string, w http.ResponseWriter, } // Check if participant exists - if user, err := participant.exists(); err == nil && user != nil { - if err := DB().Where("user_id = ?", user.ID).Find(&participant).Error; err != nil { + if user, err := participant.exists(db); err == nil && user != nil { + if err := db._db.Where("user_id = ?", user.ID).Find(&participant).Error; err != nil { return nil, err } // err := setFlashMessage(w, r, "participantExists") @@ -199,10 +199,10 @@ func (model *Participant) Create(args map[string]string, w http.ResponseWriter, // Check if a participant of the same category exists var school School - if err := DB().First(&school, participant.SchoolID).Error; err != nil { + if err := db._db.First(&school, participant.SchoolID).Error; err != nil { return nil, err } - hasCategory, err := school.HasCategory(participant) + hasCategory, err := school.HasCategory(db, participant) if err != nil { return nil, err } @@ -212,20 +212,20 @@ func (model *Participant) Create(args map[string]string, w http.ResponseWriter, participant.UserModifierCreate = NewUserModifierCreate(r) - participant, err = CreateParticipant(participant) + participant, err = CreateParticipant(db, participant) if err != nil { return nil, err } var response Response - if err := DB().First(&response, &Response{ParticipantID: participant.ID}).Error; err != nil { + if err := db._db.First(&response, &Response{ParticipantID: participant.ID}).Error; err != nil { return nil, err } response.UserModifierCreate = NewUserModifierCreate(r) - if err := DB().Save(&response).Error; err != nil { + if err := db._db.Save(&response).Error; err != nil { return nil, err } @@ -233,14 +233,14 @@ func (model *Participant) Create(args map[string]string, w http.ResponseWriter, } } -func (model *Participant) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *Participant) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { var participant Participant id := args["id"] // School user can access to its participants only! if isSchool(r) { - if err := DB().Preload("School").First(&participant, id).Error; err != nil { + if err := db._db.Preload("School").First(&participant, id).Error; err != nil { return nil, err } @@ -249,12 +249,12 @@ func (model *Participant) Read(args map[string]string, w http.ResponseWriter, r return nil, errors.NotAuthorized } - if err := DB().Preload("User").Preload("School").Preload("Category").First(&participant, id).Error; err != nil { + if err := db._db.Preload("User").Preload("School").Preload("Category").First(&participant, id).Error; err != nil { return nil, err } } else { - if err := DB().Preload("User").Preload("School").Preload("Responses").Preload("Contests").Preload("Category").First(&participant, id).Error; err != nil { + if err := db._db.Preload("User").Preload("School").Preload("Responses").Preload("Contests").Preload("Category").First(&participant, id).Error; err != nil { return nil, err } } @@ -262,7 +262,7 @@ func (model *Participant) Read(args map[string]string, w http.ResponseWriter, r return &participant, nil } -func (model *Participant) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *Participant) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { var participants []*Participant // School user can access to its participants only! @@ -271,20 +271,20 @@ func (model *Participant) ReadAll(args map[string]string, w http.ResponseWriter, if err != nil { return nil, err } - if err := DB().Preload("Category").Preload("School").Preload("Contests").Order("lastname").Find(&participants, &Participant{SchoolID: uint(schoolId)}).Error; err != nil { + if err := db._db.Preload("Category").Preload("School").Preload("Contests").Order("lastname").Find(&participants, &Participant{SchoolID: uint(schoolId)}).Error; err != nil { return nil, err } } else { - if err := DB().Preload("School").Preload("Contests").Preload("Responses").Order("created_at").Find(&participants).Error; err != nil { + if err := db._db.Preload("School").Preload("Contests").Preload("Responses").Order("created_at").Find(&participants).Error; err != nil { return nil, err } } return participants, nil } -func (model *Participant) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *Participant) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { if r.Method == "GET" { - result, err := model.Read(args, w, r) + result, err := model.Read(db, args, w, r) if err != nil { return nil, err } @@ -292,21 +292,21 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter, participant := result.(*Participant) if isSchool(r) { - if err := DB().Find(&participant.AllCategories).Error; err != nil { + if err := db._db.Find(&participant.AllCategories).Error; err != nil { return nil, err } participant.SelectedCategory = make(map[uint]string) participant.SelectedCategory[participant.CategoryID] = "selected" } else { - if err := DB().Find(&participant.AllCategories).Error; err != nil { + if err := db._db.Find(&participant.AllCategories).Error; err != nil { return nil, err } participant.SelectedCategory = make(map[uint]string) participant.SelectedCategory[participant.CategoryID] = "selected" - if err := DB().Find(&participant.AllContests).Error; err != nil { + if err := db._db.Find(&participant.AllContests).Error; err != nil { return nil, err } @@ -315,7 +315,7 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter, participant.SelectedContest[c.ID] = "selected" } - if err := DB().Find(&participant.AllSchools).Error; err != nil { + if err := db._db.Find(&participant.AllSchools).Error; err != nil { return nil, err } @@ -324,7 +324,7 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter, } return participant, nil } else { - participant, err := model.Read(args, w, r) + participant, err := model.Read(db, args, w, r) if err != nil { return nil, err } @@ -333,7 +333,7 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter, return nil, err } - if user, err := participant.(*Participant).exists(); err == nil && user != nil { + if user, err := participant.(*Participant).exists(db); err == nil && user != nil { if user.ID != participant.(*Participant).UserID { // err := setFlashMessage(w, r, "participantExists") // if err != nil { @@ -347,10 +347,10 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter, // Check if a participant of the same category exists var school School - if err := DB().First(&school, participant.(*Participant).SchoolID).Error; err != nil { + if err := db._db.First(&school, participant.(*Participant).SchoolID).Error; err != nil { return nil, err } - hasCategory, err := school.HasCategory(participant.(*Participant)) + hasCategory, err := school.HasCategory(db, participant.(*Participant)) if err != nil { return nil, err } @@ -358,35 +358,35 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter, return nil, errors.CategoryExists } - if err := DB().Where(participant.(*Participant).ContestIDs).Find(&participant.(*Participant).Contests).Error; err != nil { + if err := db._db.Where(participant.(*Participant).ContestIDs).Find(&participant.(*Participant).Contests).Error; err != nil { return nil, err } participant.(*Participant).UserModifierUpdate = NewUserModifierUpdate(r) - _, err = SaveParticipant(participant) + _, err = SaveParticipant(db, participant) if err != nil { return nil, err } - if err := DB().Model(participant).Association("Contests").Replace(participant.(*Participant).Contests).Error; err != nil { + if err := db._db.Model(participant).Association("Contests").Replace(participant.(*Participant).Contests).Error; err != nil { return nil, err } - participant, err = model.Read(args, w, r) + participant, err = model.Read(db, args, w, r) if err != nil { return nil, err } var response Response - if err := DB().First(&response, &Response{ParticipantID: participant.(*Participant).ID}).Error; err != nil { + if err := db._db.First(&response, &Response{ParticipantID: participant.(*Participant).ID}).Error; err != nil { return nil, err } response.UserModifierUpdate = NewUserModifierUpdate(r) - if err := DB().Save(&response).Error; err != nil { + if err := db._db.Save(&response).Error; err != nil { return nil, err } @@ -394,31 +394,31 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter, } } -func (model *Participant) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { - participant, err := model.Read(args, w, r) +func (model *Participant) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { + participant, err := model.Read(db, args, w, r) if err != nil { return nil, err } - if err := DB().Unscoped().Delete(participant.(*Participant)).Error; err != nil { + if err := db._db.Unscoped().Delete(participant.(*Participant)).Error; err != nil { return nil, err } return participant.(*Participant), nil } -func CreateParticipant(participant *Participant) (*Participant, error) { - if err := DB().Where([]uint(participant.ContestIDs)).Find(&participant.Contests).Error; err != nil { +func CreateParticipant(db *Database, participant *Participant) (*Participant, error) { + if err := db._db.Where([]uint(participant.ContestIDs)).Find(&participant.Contests).Error; err != nil { return nil, err } - if err := DB().Create(participant).Error; err != nil { + if err := db._db.Create(participant).Error; err != nil { return nil, err } return participant, nil } -func SaveParticipant(participant interface{}) (interface{}, error) { +func SaveParticipant(db *Database, participant interface{}) (interface{}, error) { participant.(*Participant).FiscalCode = strings.ToUpper(participant.(*Participant).FiscalCode) - if err := DB().Omit("Category", "School").Save(participant).Error; err != nil { + if err := db._db.Omit("Category", "School").Save(participant).Error; err != nil { return nil, err } return participant, nil diff --git a/orm/question.go b/orm/question.go index 5c7bd97b..c3ac111b 100644 --- a/orm/question.go +++ b/orm/question.go @@ -33,10 +33,10 @@ func (q *Question) BeforeCreate(tx *gorm.DB) error { return nil } -func (q *Question) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (q *Question) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { if r.Method == "GET" { question := new(Question) - if err := DB().Find(&question.AllContests).Error; err != nil { + if err := db._db.Find(&question.AllContests).Error; err != nil { return nil, err } return question, nil @@ -46,7 +46,7 @@ func (q *Question) Create(args map[string]string, w http.ResponseWriter, r *http if err != nil { return nil, err } - question, err = CreateQuestion(question) + question, err = CreateQuestion(db, question) if err != nil { return nil, err } @@ -54,12 +54,12 @@ func (q *Question) Create(args map[string]string, w http.ResponseWriter, r *http } } -func (q *Question) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (q *Question) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { var question Question id := args["id"] - if err := DB().Preload("Answers", func(db *gorm.DB) *gorm.DB { + if err := db._db.Preload("Answers", func(db *gorm.DB) *gorm.DB { return db.Order("answers.correct DESC") }).Preload("Contest").First(&question, id).Error; err != nil { return nil, err @@ -68,24 +68,24 @@ func (q *Question) Read(args map[string]string, w http.ResponseWriter, r *http.R return &question, nil } -func (q *Question) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (q *Question) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { var questions []*Question - if err := DB().Preload("Contest").Order("created_at").Find(&questions).Error; err != nil { + if err := db._db.Preload("Contest").Order("created_at").Find(&questions).Error; err != nil { return nil, err } return questions, nil } -func (q *Question) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (q *Question) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { if r.Method == "GET" { - result, err := q.Read(args, w, r) + result, err := q.Read(db, args, w, r) if err != nil { return nil, err } question := result.(*Question) - if err := DB().Find(&question.AllContests).Error; err != nil { + if err := db._db.Find(&question.AllContests).Error; err != nil { return nil, err } @@ -94,7 +94,7 @@ func (q *Question) Update(args map[string]string, w http.ResponseWriter, r *http return question, nil } else { - question, err := q.Read(args, w, r) + question, err := q.Read(db, args, w, r) if err != nil { return nil, err } @@ -102,11 +102,11 @@ func (q *Question) Update(args map[string]string, w http.ResponseWriter, r *http if err != nil { return nil, err } - _, err = SaveQuestion(question) + _, err = SaveQuestion(db, question) if err != nil { return nil, err } - question, err = q.Read(args, w, r) + question, err = q.Read(db, args, w, r) if err != nil { return nil, err } @@ -114,26 +114,26 @@ func (q *Question) Update(args map[string]string, w http.ResponseWriter, r *http } } -func (model *Question) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { - question, err := model.Read(args, w, r) +func (model *Question) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { + question, err := model.Read(db, args, w, r) if err != nil { return nil, err } - if err := DB().Unscoped().Delete(question.(*Question)).Error; err != nil { + if err := db._db.Unscoped().Delete(question.(*Question)).Error; err != nil { return nil, err } return question.(*Question), nil } -func CreateQuestion(question *Question) (*Question, error) { - if err := DB().Create(question).Error; err != nil { +func CreateQuestion(db *Database, question *Question) (*Question, error) { + if err := db._db.Create(question).Error; err != nil { return nil, err } return question, nil } -func SaveQuestion(question interface{}) (interface{}, error) { - if err := DB().Omit("Answers", "Contest").Save(question).Error; err != nil { +func SaveQuestion(db *Database, question interface{}) (interface{}, error) { + if err := db._db.Omit("Answers", "Contest").Save(question).Error; err != nil { return nil, err } return question, nil diff --git a/orm/region.go b/orm/region.go index 3080be44..cff5198d 100644 --- a/orm/region.go +++ b/orm/region.go @@ -60,10 +60,10 @@ type Region struct { Name string } -func CreateRegions() { +func CreateRegions(db *Database) { for _, name := range regions { var region Region - if err := currDB.FirstOrCreate(®ion, Region{Name: name}).Error; err != nil { + if err := db._db.FirstOrCreate(®ion, Region{Name: name}).Error; err != nil { panic(err) } } @@ -75,10 +75,10 @@ func (model *Region) String() string { return model.Name } -func (model *Region) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *Region) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { if r.Method == "GET" { region := new(Region) - // if err := DB().Find(®ion.AllContests).Error; err != nil { + // if err := db._db.Find(®ion.AllContests).Error; err != nil { // return nil, err // } return region, nil @@ -88,7 +88,7 @@ func (model *Region) Create(args map[string]string, w http.ResponseWriter, r *ht if err != nil { return nil, err } - region, err = CreateRegion(region) + region, err = CreateRegion(db, region) if err != nil { return nil, err } @@ -96,36 +96,36 @@ func (model *Region) Create(args map[string]string, w http.ResponseWriter, r *ht } } -func (model *Region) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *Region) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { var region Region id := args["id"] - if err := DB(). /*.Preload("Something")*/ First(®ion, id).Error; err != nil { + if err := db._db. /*.Preload("Something")*/ First(®ion, id).Error; err != nil { return nil, err } return ®ion, nil } -func (model *Region) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *Region) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { var regions []*Region - if err := DB(). /*.Preload("Something")*/ Order("created_at").Find(®ions).Error; err != nil { + if err := db._db. /*.Preload("Something")*/ Order("created_at").Find(®ions).Error; err != nil { return nil, err } return regions, nil } -func (model *Region) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *Region) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { if r.Method == "GET" { - result, err := model.Read(args, w, r) + result, err := model.Read(db, args, w, r) if err != nil { return nil, err } region := result.(*Region) - // if err := DB().Find(®ion.AllElements).Error; err != nil { + // if err := db._db.Find(®ion.AllElements).Error; err != nil { // return nil, err // } @@ -134,7 +134,7 @@ func (model *Region) Update(args map[string]string, w http.ResponseWriter, r *ht return region, nil } else { - region, err := model.Read(args, w, r) + region, err := model.Read(db, args, w, r) if err != nil { return nil, err } @@ -142,11 +142,11 @@ func (model *Region) Update(args map[string]string, w http.ResponseWriter, r *ht if err != nil { return nil, err } - _, err = SaveRegion(region) + _, err = SaveRegion(db, region) if err != nil { return nil, err } - region, err = model.Read(args, w, r) + region, err = model.Read(db, args, w, r) if err != nil { return nil, err } @@ -154,26 +154,26 @@ func (model *Region) Update(args map[string]string, w http.ResponseWriter, r *ht } } -func (model *Region) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { - region, err := model.Read(args, w, r) +func (model *Region) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { + region, err := model.Read(db, args, w, r) if err != nil { return nil, err } - if err := DB().Unscoped().Delete(region.(*Region)).Error; err != nil { + if err := db._db.Unscoped().Delete(region.(*Region)).Error; err != nil { return nil, err } return region.(*Region), nil } -func CreateRegion(region *Region) (*Region, error) { - if err := DB().Create(region).Error; err != nil { +func CreateRegion(db *Database, region *Region) (*Region, error) { + if err := db._db.Create(region).Error; err != nil { return nil, err } return region, nil } -func SaveRegion(region interface{}) (interface{}, error) { - if err := DB(). /*.Omit("Something")*/ Save(region).Error; err != nil { +func SaveRegion(db *Database, region interface{}) (interface{}, error) { + if err := db._db. /*.Omit("Something")*/ Save(region).Error; err != nil { return nil, err } return region, nil diff --git a/orm/response.go b/orm/response.go index c6af599a..3474a6cf 100644 --- a/orm/response.go +++ b/orm/response.go @@ -77,13 +77,13 @@ func (model *Response) BeforeSave(tx *gorm.DB) error { return nil } -func (model *Response) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *Response) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { if r.Method == "GET" { response := new(Response) contestID := r.URL.Query().Get("contest_id") - if err := DB().Preload("Answers").Where("contest_id = ?", contestID).Find(&response.Questions).Error; err != nil { + if err := db._db.Preload("Answers").Where("contest_id = ?", contestID).Find(&response.Questions).Error; err != nil { return nil, err } return response, nil @@ -96,7 +96,7 @@ func (model *Response) Create(args map[string]string, w http.ResponseWriter, r * response.UserModifierCreate = NewUserModifierCreate(r) - response, err = CreateResponse(response) + response, err = CreateResponse(db, response) if err != nil { return nil, err } @@ -104,12 +104,12 @@ func (model *Response) Create(args map[string]string, w http.ResponseWriter, r * } } -func (model *Response) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *Response) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { var response Response id := args["id"] - if err := DB().Preload("Contest").Preload("Participant").First(&response, id).Error; err != nil { + if err := db._db.Preload("Contest").Preload("Participant").First(&response, id).Error; err != nil { return nil, err } @@ -126,7 +126,7 @@ func (model *Response) Read(args map[string]string, w http.ResponseWriter, r *ht for _, sr := range response.SingleResponses { var answer Answer - if err := DB().First(&answer, sr.AnswerID).Error; err != nil { + if err := db._db.First(&answer, sr.AnswerID).Error; err != nil { return nil, err } @@ -151,7 +151,7 @@ func (model *Response) Read(args map[string]string, w http.ResponseWriter, r *ht // Fetch questions in the given order field := fmt.Sprintf("FIELD(id,%s)", strings.Replace(response.QuestionsOrder, " ", ",", -1)) - if err := DB().Order(field).Where("contest_id = ?", response.Contest.ID).Preload("Answers", func(db *gorm.DB) *gorm.DB { + if err := db._db.Order(field).Where("contest_id = ?", response.Contest.ID).Preload("Answers", func(db *gorm.DB) *gorm.DB { return db.Order("RAND()") }).Find(&response.Questions).Error; err != nil { return nil, err @@ -160,17 +160,17 @@ func (model *Response) Read(args map[string]string, w http.ResponseWriter, r *ht return &response, nil } -func (model *Response) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *Response) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { var responses []*Response - if err := DB().Preload("Contest").Preload("Participant").Order("created_at").Find(&responses).Error; err != nil { + if err := db._db.Preload("Contest").Preload("Participant").Order("created_at").Find(&responses).Error; err != nil { return nil, err } return responses, nil } -func (model *Response) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *Response) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { if r.Method == "GET" { - result, err := model.Read(args, w, r) + result, err := model.Read(db, args, w, r) if err != nil { return nil, err } @@ -184,10 +184,10 @@ func (model *Response) Update(args map[string]string, w http.ResponseWriter, r * return nil, errors.OutOfTime } if response.StartTime.IsZero() { - if err := DB().Model(&response).Update("start_time", time.Now()).Error; err != nil { + if err := db._db.Model(&response).Update("start_time", time.Now()).Error; err != nil { return nil, err } - if err := DB().Model(&response).Update("end_time", time.Now().Add(time.Duration(response.Contest.Duration)*time.Minute)).Error; err != nil { + if err := db._db.Model(&response).Update("end_time", time.Now().Add(time.Duration(response.Contest.Duration)*time.Minute)).Error; err != nil { return nil, err } log.Println("StartTime/EndTime", response.StartTime, response.EndTime) @@ -197,7 +197,7 @@ func (model *Response) Update(args map[string]string, w http.ResponseWriter, r * return response, nil } else { - response, err := model.Read(args, w, r) + response, err := model.Read(db, args, w, r) if err != nil { return nil, err } @@ -214,11 +214,11 @@ func (model *Response) Update(args map[string]string, w http.ResponseWriter, r * response.(*Response).UserModifierUpdate = NewUserModifierUpdate(r) - _, err = SaveResponse(response) + _, err = SaveResponse(db, response) if err != nil { return nil, err } - response, err = model.Read(args, w, r) + response, err = model.Read(db, args, w, r) if err != nil { return nil, err } @@ -226,26 +226,26 @@ func (model *Response) Update(args map[string]string, w http.ResponseWriter, r * } } -func (model *Response) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { - response, err := model.Read(args, w, r) +func (model *Response) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { + response, err := model.Read(db, args, w, r) if err != nil { return nil, err } - if err := DB().Unscoped().Delete(response.(*Response)).Error; err != nil { + if err := db._db.Unscoped().Delete(response.(*Response)).Error; err != nil { return nil, err } return response.(*Response), nil } -func CreateResponse(response *Response) (*Response, error) { - if err := DB().Create(response).Error; err != nil { +func CreateResponse(db *Database, response *Response) (*Response, error) { + if err := db._db.Create(response).Error; err != nil { return nil, err } return response, nil } -func SaveResponse(response interface{}) (interface{}, error) { - if err := DB(). /*.Omit("Something")*/ Save(response).Error; err != nil { +func SaveResponse(db *Database, response interface{}) (interface{}, error) { + if err := db._db. /*.Omit("Something")*/ Save(response).Error; err != nil { return nil, err } return response, nil diff --git a/orm/school.go b/orm/school.go index 33d90060..6646b113 100644 --- a/orm/school.go +++ b/orm/school.go @@ -70,9 +70,9 @@ func (model *School) To() string { return model.Email } -func (model *School) exists() (*User, error) { +func (model *School) exists(db *Database) (*User, error) { var user User - if err := DB().First(&user, &User{Username: model.Username()}).Error; err != nil && err != gorm.ErrRecordNotFound { + if err := db._db.First(&user, &User{Username: model.Username()}).Error; err != nil && err != gorm.ErrRecordNotFound { return nil, err } else if err == gorm.ErrRecordNotFound { return nil, nil @@ -118,11 +118,11 @@ func (model *School) AfterCreate(tx *gorm.DB) error { return nil } -func (model *School) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *School) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { if r.Method == "GET" { school := new(School) - if err := DB().Find(&school.AllRegions).Error; err != nil { + if err := db._db.Find(&school.AllRegions).Error; err != nil { return nil, err } @@ -135,8 +135,8 @@ func (model *School) Create(args map[string]string, w http.ResponseWriter, r *ht } // Check if this user already exists in the users table. - if user, err := school.exists(); err == nil && user != nil { - if err := DB().Where("user_id = ?", user.ID).Find(&school).Error; err != nil { + if user, err := school.exists(db); err == nil && user != nil { + if err := db._db.Where("user_id = ?", user.ID).Find(&school).Error; err != nil { return nil, err } // err := setFlashMessage(w, r, "schoolExists") @@ -150,7 +150,7 @@ func (model *School) Create(args map[string]string, w http.ResponseWriter, r *ht school.UserModifierCreate = NewUserModifierCreate(r) - school, err = CreateSchool(school) + school, err = CreateSchool(db, school) if err != nil { return nil, err } @@ -158,7 +158,7 @@ func (model *School) Create(args map[string]string, w http.ResponseWriter, r *ht } } -func (model *School) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *School) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { var school School id := args["id"] @@ -168,31 +168,31 @@ func (model *School) Read(args map[string]string, w http.ResponseWriter, r *http return nil, errors.NotAuthorized } - if err := DB().Preload("User").Preload("Region").Preload("Participants.Category").Preload("Participants").First(&school, id).Error; err != nil { + if err := db._db.Preload("User").Preload("Region").Preload("Participants.Category").Preload("Participants").First(&school, id).Error; err != nil { return nil, err } return &school, nil } -func (model *School) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *School) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { var schools []*School - if err := DB().Preload("Region").Preload("Participants.Category").Preload("Participants").Order("code").Find(&schools).Error; err != nil { + if err := db._db.Preload("Region").Preload("Participants.Category").Preload("Participants").Order("code").Find(&schools).Error; err != nil { return nil, err } return schools, nil } -func (model *School) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *School) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { if r.Method == "GET" { - result, err := model.Read(args, w, r) + result, err := model.Read(db, args, w, r) if err != nil { return nil, err } school := result.(*School) - if err := DB().Find(&school.AllRegions).Error; err != nil { + if err := db._db.Find(&school.AllRegions).Error; err != nil { return nil, err } @@ -201,7 +201,7 @@ func (model *School) Update(args map[string]string, w http.ResponseWriter, r *ht return school, nil } else { - school, err := model.Read(args, w, r) + school, err := model.Read(db, args, w, r) if err != nil { return nil, err } @@ -211,7 +211,7 @@ func (model *School) Update(args map[string]string, w http.ResponseWriter, r *ht } // Check if the modified school code belong to an existing school. - if user, err := school.(*School).exists(); err == nil && user != nil { + if user, err := school.(*School).exists(db); err == nil && user != nil { if user.ID != school.(*School).UserID { // err := setFlashMessage(w, r, "schoolExists") // if err != nil { @@ -225,11 +225,11 @@ func (model *School) Update(args map[string]string, w http.ResponseWriter, r *ht school.(*School).UserModifierUpdate = NewUserModifierUpdate(r) - _, err = SaveSchool(school) + _, err = SaveSchool(db, school) if err != nil { return nil, err } - school, err = model.Read(args, w, r) + school, err = model.Read(db, args, w, r) if err != nil { return nil, err } @@ -237,35 +237,35 @@ func (model *School) Update(args map[string]string, w http.ResponseWriter, r *ht } } -func (model *School) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { - school, err := model.Read(args, w, r) +func (model *School) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { + school, err := model.Read(db, args, w, r) if err != nil { return nil, err } - if err := DB().Unscoped().Delete(school.(*School)).Error; err != nil { + if err := db._db.Unscoped().Delete(school.(*School)).Error; err != nil { return nil, err } return school.(*School), nil } -func CreateSchool(school *School) (*School, error) { - if err := DB().Create(school).Error; err != nil { +func CreateSchool(db *Database, school *School) (*School, error) { + if err := db._db.Create(school).Error; err != nil { return nil, err } return school, nil } -func SaveSchool(school interface{}) (interface{}, error) { - if err := DB().Omit("Region").Save(school).Error; err != nil { +func SaveSchool(db *Database, school interface{}) (interface{}, error) { + if err := db._db.Omit("Region").Save(school).Error; err != nil { return nil, err } return school, nil } -func (model *School) HasCategory(participant *Participant) (bool, error) { +func (model *School) HasCategory(db *Database, participant *Participant) (bool, error) { var participants []*Participant - if err := DB().Where("category_id = ? AND school_id = ? AND id <> ?", participant.CategoryID, model.ID, participant.ID).Find(&participants).Error; err != nil { + if err := db._db.Where("category_id = ? AND school_id = ? AND id <> ?", participant.CategoryID, model.ID, participant.ID).Find(&participants).Error; err != nil { return false, err } return len(participants) > 0, nil diff --git a/orm/user.go b/orm/user.go index cd198b6d..ba8f1053 100644 --- a/orm/user.go +++ b/orm/user.go @@ -47,7 +47,7 @@ func (model *User) BeforeSave(tx *gorm.DB) error { return nil } -func (model *User) Create(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *User) Create(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { if r.Method == "GET" { user := new(User) return user, nil @@ -57,7 +57,7 @@ func (model *User) Create(args map[string]string, w http.ResponseWriter, r *http if err != nil { return nil, err } - user, err = CreateUser(user) + user, err = CreateUser(db, user) if err != nil { return nil, err } @@ -65,36 +65,36 @@ func (model *User) Create(args map[string]string, w http.ResponseWriter, r *http } } -func (model *User) Read(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *User) Read(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { var user User id := args["id"] - if err := DB(). /*.Preload("Something")*/ First(&user, id).Error; err != nil { + if err := db._db. /*.Preload("Something")*/ First(&user, id).Error; err != nil { return nil, err } return &user, nil } -func (model *User) ReadAll(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *User) ReadAll(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { var users []*User - if err := DB(). /*.Preload("Something")*/ Order("created_at").Find(&users).Error; err != nil { + if err := db._db. /*.Preload("Something")*/ Order("created_at").Find(&users).Error; err != nil { return nil, err } return users, nil } -func (model *User) Update(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { +func (model *User) Update(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { if r.Method == "GET" { - result, err := model.Read(args, w, r) + result, err := model.Read(db, args, w, r) if err != nil { return nil, err } user := result.(*User) - // if err := DB().Find(&user.AllElements).Error; err != nil { + // if err := db._db.Find(&user.AllElements).Error; err != nil { // return nil, err // } @@ -103,7 +103,7 @@ func (model *User) Update(args map[string]string, w http.ResponseWriter, r *http return user, nil } else { - user, err := model.Read(args, w, r) + user, err := model.Read(db, args, w, r) if err != nil { return nil, err } @@ -111,11 +111,11 @@ func (model *User) Update(args map[string]string, w http.ResponseWriter, r *http if err != nil { return nil, err } - _, err = SaveUser(user) + _, err = SaveUser(db, user) if err != nil { return nil, err } - user, err = model.Read(args, w, r) + user, err = model.Read(db, args, w, r) if err != nil { return nil, err } @@ -123,26 +123,26 @@ func (model *User) Update(args map[string]string, w http.ResponseWriter, r *http } } -func (model *User) Delete(args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { - user, err := model.Read(args, w, r) +func (model *User) Delete(db *Database, args map[string]string, w http.ResponseWriter, r *http.Request) (interface{}, error) { + user, err := model.Read(db, args, w, r) if err != nil { return nil, err } - if err := DB().Unscoped().Delete(user.(*User)).Error; err != nil { + if err := db._db.Unscoped().Delete(user.(*User)).Error; err != nil { return nil, err } return user.(*User), nil } -func CreateUser(user *User) (*User, error) { - if err := DB().Create(user).Error; err != nil { +func CreateUser(db *Database, user *User) (*User, error) { + if err := db._db.Create(user).Error; err != nil { return nil, err } return user, nil } -func SaveUser(user interface{}) (interface{}, error) { - if err := DB(). /*.Omit("Something")*/ Save(user).Error; err != nil { +func SaveUser(db *Database, user interface{}) (interface{}, error) { + if err := db._db. /*.Omit("Something")*/ Save(user).Error; err != nil { return nil, err } return user, nil diff --git a/orm/useraction.go b/orm/useraction.go index c0fb0ad6..867598e8 100644 --- a/orm/useraction.go +++ b/orm/useraction.go @@ -39,20 +39,20 @@ func NewUserModifierCreate(r *http.Request) *UserModifierCreate { } } -func (um *UserModifierCreate) CreatedBy() (*UserAction, error) { +func (um *UserModifierCreate) CreatedBy(db *Database) (*UserAction, error) { action := new(UserAction) switch um.CreatorRole { case "participant": var participant Participant - if err := DB().Preload("User").First(&participant, um.CreatorID).Error; err != nil { + if err := db._db.Preload("User").First(&participant, um.CreatorID).Error; err != nil { return nil, err } action.User = *participant.User case "school": var school School - if err := DB().Preload("User").First(&school, um.CreatorID).Error; err != nil { + if err := db._db.Preload("User").First(&school, um.CreatorID).Error; err != nil { return nil, err } action.User = *school.User @@ -80,19 +80,19 @@ func NewUserModifierUpdate(r *http.Request) *UserModifierUpdate { } } -func (um *UserModifierUpdate) UpdatedBy() (*UserAction, error) { +func (um *UserModifierUpdate) UpdatedBy(db *Database) (*UserAction, error) { action := new(UserAction) switch um.UpdaterRole { case "participant": var participant Participant - if err := DB().Preload("User").First(&participant, um.UpdaterID).Error; err != nil { + if err := db._db.Preload("User").First(&participant, um.UpdaterID).Error; err != nil { return nil, err } action.User = *participant.User case "school": var school School - if err := DB().Preload("User").First(&school, um.UpdaterID).Error; err != nil { + if err := db._db.Preload("User").First(&school, um.UpdaterID).Error; err != nil { return nil, err } action.User = *school.User