diff --git a/handlers/handlers.go b/handlers/handlers.go index ac0b588c..78110f5f 100644 --- a/handlers/handlers.go +++ b/handlers/handlers.go @@ -17,8 +17,10 @@ import ( "git.andreafazzi.eu/andrea/oef/orm" "git.andreafazzi.eu/andrea/oef/renderer" + jwtmiddleware "github.com/auth0/go-jwt-middleware" jwt "github.com/dgrijalva/jwt-go" "github.com/gorilla/mux" + "github.com/gorilla/sessions" "github.com/jinzhu/inflection" ) @@ -31,15 +33,23 @@ type PathPattern struct { } type Handlers struct { + Config *config.ConfigT + Models []interface{} - Login func() http.Handler - Logout func() http.Handler + Login func(store *sessions.CookieStore, signingKey []byte) http.Handler + Logout func(store *sessions.CookieStore) http.Handler Home func() http.Handler - GetToken func() http.Handler + GetToken func(signingKey []byte) http.Handler Static func() http.Handler Recover func(next http.Handler) http.Handler + CookieStore *sessions.CookieStore + + JWTSigningKey []byte + JWTCookieMiddleware *jwtmiddleware.JWTMiddleware + JWTHeaderMiddleware *jwtmiddleware.JWTMiddleware + Router *mux.Router } @@ -101,7 +111,7 @@ func (h *Handlers) generateModelHandlers(r *mux.Router, model interface{}) { pattern.Path( pluralizedModelName(model), ), - jwtCookie.Handler( + h.JWTCookieMiddleware.Handler( h.Recover( h.modelHandler( pluralizedModelName(model), @@ -115,7 +125,7 @@ func (h *Handlers) generateModelHandlers(r *mux.Router, model interface{}) { r.Handle(pattern.Path( pluralizedModelName(model), ), - jwtHeader.Handler( + h.JWTHeaderMiddleware.Handler( h.Recover( h.modelHandler( pluralizedModelName(model), @@ -154,30 +164,49 @@ func (h *Handlers) generateModelHandlers(r *mux.Router, model interface{}) { } -func NewHandlers(models []interface{}) *Handlers { - +func NewHandlers(config *config.ConfigT, models []interface{}) *Handlers { handlers := new(Handlers) + handlers.Config = config + + handlers.CookieStore = sessions.NewCookieStore([]byte(config.Keys.CookieStoreKey)) + handlers.Login = DefaultLoginHandler handlers.Logout = DefaultLogoutHandler handlers.Recover = DefaultRecoverHandler handlers.Home = DefaultHomeHandler handlers.GetToken = DefaultGetTokenHandler + handlers.JWTCookieMiddleware = jwtmiddleware.New(jwtmiddleware.Options{ + ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { + return []byte(config.Keys.JWTSigningKey), nil + }, + SigningMethod: jwt.SigningMethodHS256, + Extractor: handlers.cookieExtractor, + ErrorHandler: handlers.onError, + }) + + handlers.JWTHeaderMiddleware = jwtmiddleware.New(jwtmiddleware.Options{ + ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { + return config.Keys.JWTSigningKey, nil + }, + SigningMethod: jwt.SigningMethodHS256, + }) + r := mux.NewRouter() // Authentication - r.Handle("/login", handlers.Login()) - r.Handle("/logout", handlers.Logout()) + r.Handle("/login", handlers.Login(handlers.CookieStore, []byte(config.Keys.JWTSigningKey))) + r.Handle("/logout", handlers.Logout(handlers.CookieStore)) // School subscription - r.Handle("/subscribe", handlers.Login()) + r.Handle("/subscribe", handlers.Login(handlers.CookieStore, []byte(config.Keys.JWTSigningKey))) // Home - r.Handle("/", jwtCookie.Handler(handlers.Recover(handlers.Home()))) + r.Handle("/", handlers.JWTCookieMiddleware.Handler(handlers.Recover(handlers.Home()))) // Generate CRUD handlers @@ -187,16 +216,19 @@ func NewHandlers(models []interface{}) *Handlers { // Token handling - r.Handle("/get_token", handlers.GetToken()) + r.Handle("/get_token", handlers.GetToken([]byte(config.Keys.JWTSigningKey))) // Static file server r.PathPrefix("/").Handler(http.FileServer(http.Dir("./dist/"))) - return r + handlers.Router = r + + return handlers } -func onError(w http.ResponseWriter, r *http.Request, err string) { +func (h *Handlers) onError(w http.ResponseWriter, r *http.Request, err string) { + log.Print(err) http.Redirect(w, r, "/login?tpl_layout=login&tpl_content=login", http.StatusTemporaryRedirect) } @@ -209,8 +241,8 @@ func respondWithStaticFile(w http.ResponseWriter, filename string) error { return nil } -func fromCookie(r *http.Request) (string, error) { - session, err := store.Get(r, "login-session") +func (h *Handlers) cookieExtractor(r *http.Request) (string, error) { + session, err := h.CookieStore.Get(r, "login-session") if err != nil { return "", nil } @@ -238,12 +270,12 @@ func DefaultRecoverHandler(next http.Handler) http.Handler { return http.HandlerFunc(fn) } -func setFlashMessage(w http.ResponseWriter, r *http.Request, key string) error { - session, err := store.Get(r, "flash-session") +func (h *Handlers) setFlashMessage(w http.ResponseWriter, r *http.Request, key string) error { + session, err := h.CookieStore.Get(r, "flash-session") if err != nil { return err } - session.AddFlash(i18n.FlashMessages[key][config.Config.Language]) + session.AddFlash(i18n.FlashMessages[key][h.Config.Language]) err = session.Save(r, w) if err != nil { return err @@ -259,7 +291,7 @@ func hasPermission(role, path string) bool { return permissions[role][path] } -func get(w http.ResponseWriter, r *http.Request, model string, pattern PathPattern) { +func (h *Handlers) get(w http.ResponseWriter, r *http.Request, model string, pattern PathPattern) { format := r.URL.Query().Get("format") getFn, err := orm.GetFunc(pattern.Path(model)) if err != nil { @@ -270,7 +302,7 @@ func get(w http.ResponseWriter, r *http.Request, model string, pattern PathPatte role := claims["role"].(string) if !hasPermission(role, pattern.Path(model)) { log.Println("ERRORE") - setFlashMessage(w, r, "notAuthorized") + h.setFlashMessage(w, r, "notAuthorized") renderer.Render[format](w, r, fmt.Errorf("%s", "Errore di autorizzazione")) } else { @@ -367,7 +399,7 @@ func (h *Handlers) modelHandler(model string, pattern PathPattern) http.Handler switch r.Method { case "GET": - get(w, r, model, pattern) + h.get(w, r, model, pattern) case "POST": post(w, r, model, pattern) diff --git a/handlers/handlers_test.go b/handlers/handlers_test.go index 4ddd6d94..19374654 100644 --- a/handlers/handlers_test.go +++ b/handlers/handlers_test.go @@ -19,7 +19,8 @@ import ( ) var ( - token string + token string + handlers *Handlers ) // Start of setup @@ -28,6 +29,43 @@ type testSuite struct { prettytest.Suite } +func authenticate(request *http.Request, tokenString string, signingKey string) (*http.Request, error) { + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + return []byte(signingKey), nil + }) + if err != nil { + return nil, err + } + ctx := request.Context() + ctx = context.WithValue(ctx, "user", token) + return request.WithContext(ctx), nil + +} + +func requestToken(handlers *Handlers) string { + req, err := http.NewRequest("GET", "/get_token", nil) + if err != nil { + panic(err) + } + + req.SetBasicAuth("admin", "admin") + + rr := httptest.NewRecorder() + + handlers.GetToken([]byte(config.Config.Keys.JWTSigningKey)).ServeHTTP(rr, req) + + var data struct { + Token string + } + + if err := json.Unmarshal(rr.Body.Bytes(), &data); err != nil { + panic(err) + } + + return data.Token + +} + func TestRunner(t *testing.T) { prettytest.Run( t, @@ -81,30 +119,12 @@ func (t *testSuite) BeforeAll() { panic(err) } - jsonRenderer, err := renderer.NewJSONRenderer() - if err != nil { - panic(err) - } - - csvRenderer, err := renderer.NewCSVRenderer() - if err != nil { - panic(err) - } - renderer.Render = make(map[string]func(http.ResponseWriter, *http.Request, interface{}, ...url.Values)) renderer.Render["html"] = func(w http.ResponseWriter, r *http.Request, data interface{}, options ...url.Values) { htmlRenderer.Render(w, r, data, options...) } - renderer.Render["json"] = func(w http.ResponseWriter, r *http.Request, data interface{}, options ...url.Values) { - jsonRenderer.Render(w, r, data, options...) - } - - renderer.Render["csv"] = func(w http.ResponseWriter, r *http.Request, data interface{}, options ...url.Values) { - csvRenderer.Render(w, r, data, options...) - } - // Load the configuration err = config.ReadFile("testdata/config.yaml", config.Config) @@ -112,30 +132,9 @@ func (t *testSuite) BeforeAll() { panic(err) } config.Config.LogLevel = config.LOG_LEVEL_OFF + handlers = NewHandlers(config.Config, models) + token = requestToken(handlers) - req, err := http.NewRequest("GET", "/get_token", nil) - if err != nil { - panic(err) - } - - req.SetBasicAuth("admin", "admin") - - rr := httptest.NewRecorder() - signingKey = []byte("secret") - tokenHandler().ServeHTTP(rr, req) - - var data struct { - Token string - UserID string - } - - if err := json.Unmarshal(rr.Body.Bytes(), &data); err != nil { - panic(err) - } - - token = data.Token - - Handlers(models) if err := orm.MapHandlers(models); err != nil { panic(err) } @@ -156,22 +155,26 @@ func (t *testSuite) TestReadAllContests() { rr := httptest.NewRecorder() - tkn, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) { - return []byte("secret"), nil - }) + req, err = authenticate(req, token, config.Config.Keys.JWTSigningKey) + // tkn, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) { + // return []byte("secret"), nil + // }) + // if err != nil { + // panic(err) + // } + // ctx := req.Context() + // ctx = context.WithValue(ctx, "user", tkn) + // req = req.WithContext(ctx) + t.Nil(err) + if err != nil { - panic(err) - } - ctx := req.Context() - ctx = context.WithValue(ctx, "user", tkn) - req = req.WithContext(ctx) + handlers.modelHandler("contests", pattern).ServeHTTP(rr, req) - modelHandler("contests", pattern).ServeHTTP(rr, req) + t.Equal(http.StatusOK, rr.Code) - t.Equal(http.StatusOK, rr.Code) - - if !t.Failed() { - t.True(strings.Contains(rr.Body.String(), "JUNIOR Contest")) + if !t.Failed() { + t.True(strings.Contains(rr.Body.String(), "JUNIOR Contest")) + } } } diff --git a/handlers/login.go b/handlers/login.go index 2fdd232a..f21e7d8e 100644 --- a/handlers/login.go +++ b/handlers/login.go @@ -10,7 +10,6 @@ import ( "git.andreafazzi.eu/andrea/oef/config" "git.andreafazzi.eu/andrea/oef/orm" "git.andreafazzi.eu/andrea/oef/renderer" - jwtmiddleware "github.com/auth0/go-jwt-middleware" jwt "github.com/dgrijalva/jwt-go" "github.com/gorilla/sessions" ) @@ -23,27 +22,27 @@ type UserToken struct { } var ( - signingKey = []byte(config.Config.Keys.JWTSigningKey) - store = sessions.NewCookieStore([]byte(config.Config.Keys.CookieStoreKey)) +// signingKey = []byte(config.Config.Keys.JWTSigningKey) +// store = sessions.NewCookieStore([]byte(config.Config.Keys.CookieStoreKey)) - jwtCookie = jwtmiddleware.New(jwtmiddleware.Options{ - ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { - return signingKey, nil - }, - SigningMethod: jwt.SigningMethodHS256, - Extractor: fromCookie, - ErrorHandler: onError, - }) +// jwtCookie = jwtmiddleware.New(jwtmiddleware.Options{ +// ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { +// return signingKey, nil +// }, +// SigningMethod: jwt.SigningMethodHS256, +// Extractor: fromCookie, +// ErrorHandler: onError, +// }) - jwtHeader = jwtmiddleware.New(jwtmiddleware.Options{ - ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { - return signingKey, nil - }, - SigningMethod: jwt.SigningMethodHS256, - }) +// jwtHeader = jwtmiddleware.New(jwtmiddleware.Options{ +// ValidationKeyGetter: func(token *jwt.Token) (interface{}, error) { +// return signingKey, nil +// }, +// SigningMethod: jwt.SigningMethodHS256, +// }) ) -func DefaultLogoutHandler() http.Handler { +func DefaultLogoutHandler(store *sessions.CookieStore) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { session, err := store.Get(r, "login-session") if err != nil { @@ -60,14 +59,14 @@ func DefaultLogoutHandler() http.Handler { return http.HandlerFunc(fn) } -func DefaultLoginHandler() http.Handler { +func DefaultLoginHandler(store *sessions.CookieStore, signingKey []byte) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { if r.Method == "GET" { renderer.Render["html"](w, r, nil, r.URL.Query()) } if r.Method == "POST" { r.ParseForm() - token, err := getToken(r.FormValue("username"), r.FormValue("password")) + token, err := getToken(r.FormValue("username"), r.FormValue("password"), signingKey) if err != nil { http.Redirect(w, r, "/login?tpl_layout=login&tpl_content=login&failed=true", http.StatusSeeOther) } else { @@ -128,9 +127,8 @@ func checkCredential(username string, password string) (*UserToken, error) { } // FIXME: Refactor the functions above please!!! -func getToken(username string, password string) ([]byte, error) { +func getToken(username string, password string, signingKey []byte) ([]byte, error) { user, err := checkCredential(username, password) - if err != nil { return nil, err } @@ -156,34 +154,18 @@ func getToken(username string, password string) ([]byte, error) { } // FIXME: Refactor the functions above please!!! -func DefaultGetTokenHandler() http.Handler { +func DefaultGetTokenHandler(signingKey []byte) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { username, password, _ := r.BasicAuth() - user, err := checkCredential(username, password) - if err != nil { - panic(err) - } - - /* Set token claims */ - claims := make(map[string]interface{}) - claims["admin"] = user.Admin - claims["username"] = user.Username - claims["role"] = user.Role - claims["user_id"] = user.UserID - claims["exp"] = time.Now().Add(time.Hour * 24).Unix() - - /* Create the token */ - token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims(claims)) - - /* Sign the token with our secret */ - tokenString, err := token.SignedString(signingKey) + token, err := getToken(username, password, signingKey) if err != nil { panic(err) } w.Header().Set("Content-Type", "application/json; charset=utf-8") - w.Write([]byte(fmt.Sprintf("{\"Token\":\"%s\",\"User\":\"%s\"}", tokenString, user.Username))) + w.Write([]byte(fmt.Sprintf("{\"Token\":\"%s\"}", string(token)))) + } return http.HandlerFunc(fn) } diff --git a/main.go b/main.go index a375ee21..33cbc239 100644 --- a/main.go +++ b/main.go @@ -84,7 +84,7 @@ func main() { } log.Println("OEF is listening to port 3000...") - if err := http.ListenAndServe(":3000", handlers.LoggingHandler(os.Stdout, oef_handlers.Handlers(models))); err != nil { + if err := http.ListenAndServe(":3000", handlers.LoggingHandler(os.Stdout, oef_handlers.NewHandlers(config.Config, models).Router)); err != nil { panic(err) } diff --git a/orm/orm.go b/orm/orm.go index ff491086..a1c4e251 100644 --- a/orm/orm.go +++ b/orm/orm.go @@ -8,7 +8,6 @@ import ( "strings" "git.andreafazzi.eu/andrea/oef/config" - "github.com/gorilla/sessions" "github.com/jinzhu/gorm" "github.com/jinzhu/inflection" @@ -21,26 +20,26 @@ type IDer interface { type GetFn func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error) -var ( - fns map[string]func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error) - currDB *gorm.DB - store = sessions.NewCookieStore([]byte(config.Config.Keys.CookieStoreKey)) -) +type Database struct { + Config *config.ConfigT -func init() { - fns = make(map[string]func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error), 0) + _db *gorm.DB + fns map[string]func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error) } -func New(connection string) (*gorm.DB, error) { - db, err := gorm.Open("mysql", connection) +func NewDatabase(config *config.ConfigT) (*Database, error) { + db := new(Database) + + db.fns = make(db*gorm.DB, map[string]func(map[string]string, http.ResponseWriter, *http.Request) (interface{}, error), 0) + db._db, err := gorm.Open("mysql", fmt.Sprintf("%s?%s", config.Config.Orm.Connection, config.Config.Orm.Options)) if err != nil { return nil, err } - return db, nil + return db } -func AutoMigrate(models ...interface{}) { - if err := currDB.AutoMigrate(models...).Error; err != nil { +func (db *Database) AutoMigrate(models ...interface{}) { + if err := db._db.AutoMigrate(models...).Error; err != nil { panic(err) } diff --git a/orm/participant.go b/orm/participant.go index 1903f61e..bec4fb83 100644 --- a/orm/participant.go +++ b/orm/participant.go @@ -7,9 +7,7 @@ import ( "strconv" "strings" - "git.andreafazzi.eu/andrea/oef/config" "git.andreafazzi.eu/andrea/oef/errors" - "git.andreafazzi.eu/andrea/oef/i18n" "git.andreafazzi.eu/andrea/oef/renderer" "github.com/jinzhu/gorm" ) @@ -81,18 +79,18 @@ func (model *Participant) String() string { return fmt.Sprintf("%s %s", strings.ToUpper(model.Lastname), strings.Title(strings.ToLower(model.Firstname))) } -func setFlashMessage(w http.ResponseWriter, r *http.Request, key string) error { - session, err := store.Get(r, "flash-session") - if err != nil { - return err - } - session.AddFlash(i18n.FlashMessages[key][config.Config.Language]) - err = session.Save(r, w) - if err != nil { - return err - } - return nil -} +// func setFlashMessage(w http.ResponseWriter, r *http.Request, key string) error { +// session, err := store.Get(r, "flash-session") +// if err != nil { +// return err +// } +// session.AddFlash(i18n.FlashMessages[key][config.Config.Language]) +// err = session.Save(r, w) +// if err != nil { +// return err +// } +// return nil +// } func (model *Participant) exists() (*User, error) { var user User @@ -181,10 +179,10 @@ func (model *Participant) Create(args map[string]string, w http.ResponseWriter, if err := DB().Where("user_id = ?", user.ID).Find(&participant).Error; err != nil { return nil, err } - err := setFlashMessage(w, r, "participantExists") - if err != nil { - return nil, err - } + // err := setFlashMessage(w, r, "participantExists") + // if err != nil { + // return nil, err + // } return participant, nil } else if err != nil { return nil, err @@ -247,7 +245,7 @@ func (model *Participant) Read(args map[string]string, w http.ResponseWriter, r } if strconv.Itoa(int(participant.SchoolID)) != getUserIDFromToken(r) { - setFlashMessage(w, r, "notAuthorized") + // setFlashMessage(w, r, "notAuthorized") return nil, errors.NotAuthorized } @@ -337,10 +335,10 @@ func (model *Participant) Update(args map[string]string, w http.ResponseWriter, if user, err := participant.(*Participant).exists(); err == nil && user != nil { if user.ID != participant.(*Participant).UserID { - err := setFlashMessage(w, r, "participantExists") - if err != nil { - return nil, err - } + // err := setFlashMessage(w, r, "participantExists") + // if err != nil { + // return nil, err + // } return participant, nil } } else if err != nil { diff --git a/orm/school.go b/orm/school.go index 88b46baa..33d90060 100644 --- a/orm/school.go +++ b/orm/school.go @@ -139,10 +139,10 @@ func (model *School) Create(args map[string]string, w http.ResponseWriter, r *ht if err := DB().Where("user_id = ?", user.ID).Find(&school).Error; err != nil { return nil, err } - err := setFlashMessage(w, r, "schoolExists") - if err != nil { - return nil, err - } + // err := setFlashMessage(w, r, "schoolExists") + // if err != nil { + // return nil, err + // } return school, nil } else if err != nil { return nil, err @@ -164,7 +164,7 @@ func (model *School) Read(args map[string]string, w http.ResponseWriter, r *http id := args["id"] if isSchool(r) && id != getUserIDFromToken(r) { - setFlashMessage(w, r, "notAuthorized") + // setFlashMessage(w, r, "notAuthorized") return nil, errors.NotAuthorized } @@ -213,10 +213,10 @@ func (model *School) Update(args map[string]string, w http.ResponseWriter, r *ht // Check if the modified school code belong to an existing school. if user, err := school.(*School).exists(); err == nil && user != nil { if user.ID != school.(*School).UserID { - err := setFlashMessage(w, r, "schoolExists") - if err != nil { - return nil, err - } + // err := setFlashMessage(w, r, "schoolExists") + // if err != nil { + // return nil, err + // } return school, nil } } else if err != nil { diff --git a/renderer/renderer.go b/renderer/renderer.go index 105f970d..2c8f2c99 100644 --- a/renderer/renderer.go +++ b/renderer/renderer.go @@ -54,6 +54,7 @@ var ( contentTypeToFormat = map[string]string{ "application/x-www-form-urlencoded": "html", + "text/html; charset=utf-8": "html", "application/json": "json", "text/csv; charset=utf-8": "csv", "application/pdf": "pdf",