Working on db refactoring

This commit is contained in:
Andrea Fazzi 2020-01-14 16:28:27 +01:00
parent 3f61d86465
commit 2ad7e3d180
8 changed files with 180 additions and 165 deletions

View file

@ -17,8 +17,10 @@ import (
"git.andreafazzi.eu/andrea/oef/orm"
"git.andreafazzi.eu/andrea/oef/renderer"
jwtmiddleware "github.com/auth0/go-jwt-middleware"
jwt "github.com/dgrijalva/jwt-go"
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
"github.com/jinzhu/inflection"
)
@ -31,15 +33,23 @@ type PathPattern struct {
}
type Handlers struct {
Config *config.ConfigT
Models []interface{}
Login func() http.Handler
Logout func() http.Handler
Login func(store *sessions.CookieStore, signingKey []byte) http.Handler
Logout func(store *sessions.CookieStore) http.Handler
Home func() http.Handler
GetToken func() http.Handler
GetToken func(signingKey []byte) http.Handler
Static func() http.Handler
Recover func(next http.Handler) http.Handler
CookieStore *sessions.CookieStore
JWTSigningKey []byte
JWTCookieMiddleware *jwtmiddleware.JWTMiddleware
JWTHeaderMiddleware *jwtmiddleware.JWTMiddleware
Router *mux.Router
}
@ -101,7 +111,7 @@ func (h *Handlers) generateModelHandlers(r *mux.Router, model interface{}) {
pattern.Path(
pluralizedModelName(model),
),
jwtCookie.Handler(
h.JWTCookieMiddleware.Handler(
h.Recover(
h.modelHandler(
pluralizedModelName(model),
@ -115,7 +125,7 @@ func (h *Handlers) generateModelHandlers(r *mux.Router, model interface{}) {
r.Handle(pattern.Path(
pluralizedModelName(model),
),
jwtHeader.Handler(
h.JWTHeaderMiddleware.Handler(
h.Recover(
h.modelHandler(
pluralizedModelName(model),
@ -154,30 +164,49 @@ func (h *Handlers) generateModelHandlers(r *mux.Router, model interface{}) {
}
func NewHandlers(models []interface{}) *Handlers {
func NewHandlers(config *config.ConfigT, models []interface{}) *Handlers {
handlers := new(Handlers)
handlers.Config = config
handlers.CookieStore = sessions.NewCookieStore([]byte(config.Keys.CookieStoreKey))
handlers.Login = DefaultLoginHandler
handlers.Logout = DefaultLogoutHandler
handlers.Recover = DefaultRecoverHandler
handlers.Home = DefaultHomeHandler
handlers.GetToken = DefaultGetTokenHandler
handlers.JWTCookieMiddleware = jwtmiddleware.New(jwtmiddleware.Options{
ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
return []byte(config.Keys.JWTSigningKey), nil
},
SigningMethod: jwt.SigningMethodHS256,
Extractor: handlers.cookieExtractor,
ErrorHandler: handlers.onError,
})
handlers.JWTHeaderMiddleware = jwtmiddleware.New(jwtmiddleware.Options{
ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
return config.Keys.JWTSigningKey, nil
},
SigningMethod: jwt.SigningMethodHS256,
})
r := mux.NewRouter()
// Authentication
r.Handle("/login", handlers.Login())
r.Handle("/logout", handlers.Logout())
r.Handle("/login", handlers.Login(handlers.CookieStore, []byte(config.Keys.JWTSigningKey)))
r.Handle("/logout", handlers.Logout(handlers.CookieStore))
// School subscription
r.Handle("/subscribe", handlers.Login())
r.Handle("/subscribe", handlers.Login(handlers.CookieStore, []byte(config.Keys.JWTSigningKey)))
// Home
r.Handle("/", jwtCookie.Handler(handlers.Recover(handlers.Home())))
r.Handle("/", handlers.JWTCookieMiddleware.Handler(handlers.Recover(handlers.Home())))
// Generate CRUD handlers
@ -187,16 +216,19 @@ func NewHandlers(models []interface{}) *Handlers {
// Token handling
r.Handle("/get_token", handlers.GetToken())
r.Handle("/get_token", handlers.GetToken([]byte(config.Keys.JWTSigningKey)))
// Static file server
r.PathPrefix("/").Handler(http.FileServer(http.Dir("./dist/")))
return r
handlers.Router = r
return handlers
}
func onError(w http.ResponseWriter, r *http.Request, err string) {
func (h *Handlers) onError(w http.ResponseWriter, r *http.Request, err string) {
log.Print(err)
http.Redirect(w, r, "/login?tpl_layout=login&tpl_content=login", http.StatusTemporaryRedirect)
}
@ -209,8 +241,8 @@ func respondWithStaticFile(w http.ResponseWriter, filename string) error {
return nil
}
func fromCookie(r *http.Request) (string, error) {
session, err := store.Get(r, "login-session")
func (h *Handlers) cookieExtractor(r *http.Request) (string, error) {
session, err := h.CookieStore.Get(r, "login-session")
if err != nil {
return "", nil
}
@ -238,12 +270,12 @@ func DefaultRecoverHandler(next http.Handler) http.Handler {
return http.HandlerFunc(fn)
}
func setFlashMessage(w http.ResponseWriter, r *http.Request, key string) error {
session, err := store.Get(r, "flash-session")
func (h *Handlers) setFlashMessage(w http.ResponseWriter, r *http.Request, key string) error {
session, err := h.CookieStore.Get(r, "flash-session")
if err != nil {
return err
}
session.AddFlash(i18n.FlashMessages[key][config.Config.Language])
session.AddFlash(i18n.FlashMessages[key][h.Config.Language])
err = session.Save(r, w)
if err != nil {
return err
@ -259,7 +291,7 @@ func hasPermission(role, path string) bool {
return permissions[role][path]
}
func get(w http.ResponseWriter, r *http.Request, model string, pattern PathPattern) {
func (h *Handlers) get(w http.ResponseWriter, r *http.Request, model string, pattern PathPattern) {
format := r.URL.Query().Get("format")
getFn, err := orm.GetFunc(pattern.Path(model))
if err != nil {
@ -270,7 +302,7 @@ func get(w http.ResponseWriter, r *http.Request, model string, pattern PathPatte
role := claims["role"].(string)
if !hasPermission(role, pattern.Path(model)) {
log.Println("ERRORE")
setFlashMessage(w, r, "notAuthorized")
h.setFlashMessage(w, r, "notAuthorized")
renderer.Render[format](w, r, fmt.Errorf("%s", "Errore di autorizzazione"))
} else {
@ -367,7 +399,7 @@ func (h *Handlers) modelHandler(model string, pattern PathPattern) http.Handler
switch r.Method {
case "GET":
get(w, r, model, pattern)
h.get(w, r, model, pattern)
case "POST":
post(w, r, model, pattern)

View file

@ -19,7 +19,8 @@ import (
)
var (
token string
token string
handlers *Handlers
)
// Start of setup
@ -28,6 +29,43 @@ type testSuite struct {
prettytest.Suite
}
func authenticate(request *http.Request, tokenString string, signingKey string) (*http.Request, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
return []byte(signingKey), nil
})
if err != nil {
return nil, err
}
ctx := request.Context()
ctx = context.WithValue(ctx, "user", token)
return request.WithContext(ctx), nil
}
func requestToken(handlers *Handlers) string {
req, err := http.NewRequest("GET", "/get_token", nil)
if err != nil {
panic(err)
}
req.SetBasicAuth("admin", "admin")
rr := httptest.NewRecorder()
handlers.GetToken([]byte(config.Config.Keys.JWTSigningKey)).ServeHTTP(rr, req)
var data struct {
Token string
}
if err := json.Unmarshal(rr.Body.Bytes(), &data); err != nil {
panic(err)
}
return data.Token
}
func TestRunner(t *testing.T) {
prettytest.Run(
t,
@ -81,30 +119,12 @@ func (t *testSuite) BeforeAll() {
panic(err)
}
jsonRenderer, err := renderer.NewJSONRenderer()
if err != nil {
panic(err)
}
csvRenderer, err := renderer.NewCSVRenderer()
if err != nil {
panic(err)
}
renderer.Render = make(map[string]func(http.ResponseWriter, *http.Request, interface{}, ...url.Values))
renderer.Render["html"] = func(w http.ResponseWriter, r *http.Request, data interface{}, options ...url.Values) {
htmlRenderer.Render(w, r, data, options...)
}
renderer.Render["json"] = func(w http.ResponseWriter, r *http.Request, data interface{}, options ...url.Values) {
jsonRenderer.Render(w, r, data, options...)
}
renderer.Render["csv"] = func(w http.ResponseWriter, r *http.Request, data interface{}, options ...url.Values) {
csvRenderer.Render(w, r, data, options...)
}
// Load the configuration
err = config.ReadFile("testdata/config.yaml", config.Config)
@ -112,30 +132,9 @@ func (t *testSuite) BeforeAll() {
panic(err)
}
config.Config.LogLevel = config.LOG_LEVEL_OFF
handlers = NewHandlers(config.Config, models)
token = requestToken(handlers)
req, err := http.NewRequest("GET", "/get_token", nil)
if err != nil {
panic(err)
}
req.SetBasicAuth("admin", "admin")
rr := httptest.NewRecorder()
signingKey = []byte("secret")
tokenHandler().ServeHTTP(rr, req)
var data struct {
Token string
UserID string
}
if err := json.Unmarshal(rr.Body.Bytes(), &data); err != nil {
panic(err)
}
token = data.Token
Handlers(models)
if err := orm.MapHandlers(models); err != nil {
panic(err)
}
@ -156,22 +155,26 @@ func (t *testSuite) TestReadAllContests() {
rr := httptest.NewRecorder()
tkn, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) {
return []byte("secret"), nil
})
req, err = authenticate(req, token, config.Config.Keys.JWTSigningKey)
// tkn, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) {
// return []byte("secret"), nil
// })
// if err != nil {
// panic(err)
// }
// ctx := req.Context()
// ctx = context.WithValue(ctx, "user", tkn)
// req = req.WithContext(ctx)
t.Nil(err)
if err != nil {
panic(err)
}
ctx := req.Context()
ctx = context.WithValue(ctx, "user", tkn)
req = req.WithContext(ctx)
handlers.modelHandler("contests", pattern).ServeHTTP(rr, req)
modelHandler("contests", pattern).ServeHTTP(rr, req)
t.Equal(http.StatusOK, rr.Code)
t.Equal(http.StatusOK, rr.Code)
if !t.Failed() {
t.True(strings.Contains(rr.Body.String(), "JUNIOR Contest"))
if !t.Failed() {
t.True(strings.Contains(rr.Body.String(), "JUNIOR Contest"))
}
}
}

View file

@ -10,7 +10,6 @@ import (
"git.andreafazzi.eu/andrea/oef/config"
"git.andreafazzi.eu/andrea/oef/orm"
"git.andreafazzi.eu/andrea/oef/renderer"
jwtmiddleware "github.com/auth0/go-jwt-middleware"
jwt "github.com/dgrijalva/jwt-go"
"github.com/gorilla/sessions"
)
@ -23,27 +22,27 @@ type UserToken struct {
}
var (
signingKey = []byte(config.Config.Keys.JWTSigningKey)
store = sessions.NewCookieStore([]byte(config.Config.Keys.CookieStoreKey))
// signingKey = []byte(config.Config.Keys.JWTSigningKey)
// store = sessions.NewCookieStore([]byte(config.Config.Keys.CookieStoreKey))
jwtCookie = jwtmiddleware.New(jwtmiddleware.Options{
ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
return signingKey, nil
},
SigningMethod: jwt.SigningMethodHS256,
Extractor: fromCookie,
ErrorHandler: onError,
})
// jwtCookie = jwtmiddleware.New(jwtmiddleware.Options{
// ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
// return signingKey, nil
// },
// SigningMethod: jwt.SigningMethodHS256,
// Extractor: fromCookie,
// ErrorHandler: onError,
// })
jwtHeader = jwtmiddleware.New(jwtmiddleware.Options{
ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
return signingKey, nil
},
SigningMethod: jwt.SigningMethodHS256,
})
// jwtHeader = jwtmiddleware.New(jwtmiddleware.Options{
// ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) {
// return signingKey, nil
// },
// SigningMethod: jwt.SigningMethodHS256,
// })
)
func DefaultLogoutHandler() http.Handler {
func DefaultLogoutHandler(store *sessions.CookieStore) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
session, err := store.Get(r, "login-session")
if err != nil {
@ -60,14 +59,14 @@ func DefaultLogoutHandler() http.Handler {
return http.HandlerFunc(fn)
}
func DefaultLoginHandler() http.Handler {
func DefaultLoginHandler(store *sessions.CookieStore, signingKey []byte) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
if r.Method == "GET" {
renderer.Render["html"](w, r, nil, r.URL.Query())
}
if r.Method == "POST" {
r.ParseForm()
token, err := getToken(r.FormValue("username"), r.FormValue("password"))
token, err := getToken(r.FormValue("username"), r.FormValue("password"), signingKey)
if err != nil {
http.Redirect(w, r, "/login?tpl_layout=login&tpl_content=login&failed=true", http.StatusSeeOther)
} else {
@ -128,9 +127,8 @@ func checkCredential(username string, password string) (*UserToken, error) {
}
// FIXME: Refactor the functions above please!!!
func getToken(username string, password string) ([]byte, error) {
func getToken(username string, password string, signingKey []byte) ([]byte, error) {
user, err := checkCredential(username, password)
if err != nil {
return nil, err
}
@ -156,34 +154,18 @@ func getToken(username string, password string) ([]byte, error) {
}
// FIXME: Refactor the functions above please!!!
func DefaultGetTokenHandler() http.Handler {
func DefaultGetTokenHandler(signingKey []byte) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
username, password, _ := r.BasicAuth()
user, err := checkCredential(username, password)
if err != nil {
panic(err)
}
/* Set token claims */
claims := make(map[string]interface{})
claims["admin"] = user.Admin
claims["username"] = user.Username
claims["role"] = user.Role
claims["user_id"] = user.UserID
claims["exp"] = time.Now().Add(time.Hour * 24).Unix()
/* Create the token */
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims(claims))
/* Sign the token with our secret */
tokenString, err := token.SignedString(signingKey)
token, err := getToken(username, password, signingKey)
if err != nil {
panic(err)
}
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.Write([]byte(fmt.Sprintf("{\"Token\":\"%s\",\"User\":\"%s\"}", tokenString, user.Username)))
w.Write([]byte(fmt.Sprintf("{\"Token\":\"%s\"}", string(token))))
}
return http.HandlerFunc(fn)
}

View file

@ -84,7 +84,7 @@ func main() {
}
log.Println("OEF is listening to port 3000...")
if err := http.ListenAndServe(":3000", handlers.LoggingHandler(os.Stdout, oef_handlers.Handlers(models))); err != nil {
if err := http.ListenAndServe(":3000", handlers.LoggingHandler(os.Stdout, oef_handlers.NewHandlers(config.Config, models).Router)); err != nil {
panic(err)
}

View file

@ -8,7 +8,6 @@ import (
"strings"
"git.andreafazzi.eu/andrea/oef/config"
"github.com/gorilla/sessions"
"github.com/jinzhu/gorm"
"github.com/jinzhu/inflection"
@ -21,26 +20,26 @@ type IDer interface {
type GetFn func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error)
var (
fns map[string]func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error)
currDB *gorm.DB
store = sessions.NewCookieStore([]byte(config.Config.Keys.CookieStoreKey))
)
type Database struct {
Config *config.ConfigT
func init() {
fns = make(map[string]func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error), 0)
_db *gorm.DB
fns map[string]func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error)
}
func New(connection string) (*gorm.DB, error) {
db, err := gorm.Open("mysql", connection)
func NewDatabase(config *config.ConfigT) (*Database, error) {
db := new(Database)
db.fns = make(db*gorm.DB, map[string]func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error), 0)
db._db, err := gorm.Open("mysql", fmt.Sprintf("%s?%s", config.Config.Orm.Connection, config.Config.Orm.Options))
if err != nil {
return nil, err
}
return db, nil
return db
}
func AutoMigrate(models ...interface{}) {
if err := currDB.AutoMigrate(models...).Error; err != nil {
func (db *Database) AutoMigrate(models ...interface{}) {
if err := db._db.AutoMigrate(models...).Error; err != nil {
panic(err)
}

View file

@ -7,9 +7,7 @@ import (
"strconv"
"strings"
"git.andreafazzi.eu/andrea/oef/config"
"git.andreafazzi.eu/andrea/oef/errors"
"git.andreafazzi.eu/andrea/oef/i18n"
"git.andreafazzi.eu/andrea/oef/renderer"
"github.com/jinzhu/gorm"
)
@ -81,18 +79,18 @@ func (model *Participant) String() string {
return fmt.Sprintf("%s %s", strings.ToUpper(model.Lastname), strings.Title(strings.ToLower(model.Firstname)))
}
func setFlashMessage(w http.ResponseWriter, r *http.Request, key string) error {
session, err := store.Get(r, "flash-session")
if err != nil {
return err
}
session.AddFlash(i18n.FlashMessages[key][config.Config.Language])
err = session.Save(r, w)
if err != nil {
return err
}
return nil
}
// func setFlashMessage(w http.ResponseWriter, r *http.Request, key string) error {
// session, err := store.Get(r, "flash-session")
// if err != nil {
// return err
// }
// session.AddFlash(i18n.FlashMessages[key][config.Config.Language])
// err = session.Save(r, w)
// if err != nil {
// return err
// }
// return nil
// }
func (model *Participant) exists() (*User, error) {
var user User
@ -181,10 +179,10 @@ func (model *Participant) Create(args map[string]string, w http.ResponseWriter,
if err := DB().Where("user_id = ?", user.ID).Find(&participant).Error; err != nil {
return nil, err
}
err := setFlashMessage(w, r, "participantExists")
if err != nil {
return nil, err
}
// err := setFlashMessage(w, r, "participantExists")
// if err != nil {
// return nil, err
// }
return participant, nil
} else if err != nil {
return nil, err
@ -247,7 +245,7 @@ func (model *Participant) Read(args map[string]string, w http.ResponseWriter, r
}
if strconv.Itoa(int(participant.SchoolID)) != getUserIDFromToken(r) {
setFlashMessage(w, r, "notAuthorized")
// setFlashMessage(w, r, "notAuthorized")
return nil, errors.NotAuthorized
}
@ -337,10 +335,10 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter,
if user, err := participant.(*Participant).exists(); err == nil && user != nil {
if user.ID != participant.(*Participant).UserID {
err := setFlashMessage(w, r, "participantExists")
if err != nil {
return nil, err
}
// err := setFlashMessage(w, r, "participantExists")
// if err != nil {
// return nil, err
// }
return participant, nil
}
} else if err != nil {

View file

@ -139,10 +139,10 @@ func (model *School) Create(args map[string]string, w http.ResponseWriter, r *ht
if err := DB().Where("user_id = ?", user.ID).Find(&school).Error; err != nil {
return nil, err
}
err := setFlashMessage(w, r, "schoolExists")
if err != nil {
return nil, err
}
// err := setFlashMessage(w, r, "schoolExists")
// if err != nil {
// return nil, err
// }
return school, nil
} else if err != nil {
return nil, err
@ -164,7 +164,7 @@ func (model *School) Read(args map[string]string, w http.ResponseWriter, r *http
id := args["id"]
if isSchool(r) && id != getUserIDFromToken(r) {
setFlashMessage(w, r, "notAuthorized")
// setFlashMessage(w, r, "notAuthorized")
return nil, errors.NotAuthorized
}
@ -213,10 +213,10 @@ func (model *School) Update(args map[string]string, w http.ResponseWriter, r *ht
// Check if the modified school code belong to an existing school.
if user, err := school.(*School).exists(); err == nil && user != nil {
if user.ID != school.(*School).UserID {
err := setFlashMessage(w, r, "schoolExists")
if err != nil {
return nil, err
}
// err := setFlashMessage(w, r, "schoolExists")
// if err != nil {
// return nil, err
// }
return school, nil
}
} else if err != nil {

View file

@ -54,6 +54,7 @@ var (
contentTypeToFormat = map[string]string{
"application/x-www-form-urlencoded": "html",
"text/html; charset=utf-8": "html",
"application/json": "json",
"text/csv; charset=utf-8": "csv",
"application/pdf": "pdf",