Commit 53d71c6d authored by Julien Schröter's avatar Julien Schröter Committed by Julien Schröter

Extend TokenFactory by creation and processing of PasswordResetTokens

parent e60101b7
......@@ -12,10 +12,11 @@ type TokenType int
const (
ConfirmRegistrationTokenType TokenType = iota
ForgotPasswordTokenType
PasswordResetTokenType
)
var ErrTokenExpired = errors.New("token expired")
var ErrOddTokenType = errors.New("odd token type")
func GetTokenFactory(secret []byte) TokenFactory {
return &tokenFactory{secret: secret}
......@@ -24,6 +25,8 @@ func GetTokenFactory(secret []byte) TokenFactory {
type TokenFactory interface {
CreateConfirmRegistrationToken(userID uint32, email string) (string, error)
ParseConfirmRegistrationToken(token string) (uint32, string, error)
CreatePasswordResetToken(userID uint32, hash string) (string, error)
ParsePasswordResetToken(token string) (uint32, string, error)
}
type tokenFactory struct {
......@@ -48,6 +51,15 @@ func (f *tokenFactory) CreateConfirmRegistrationToken(userID uint32, email strin
})
}
func (f *tokenFactory) CreatePasswordResetToken(userID uint32, hash string) (string, error) {
return f.CreateToken(jwt.MapClaims{
"type": PasswordResetTokenType,
"userID": userID,
"hash": hash,
"exp": time.Now().Add(time.Hour * 24).Unix(),
})
}
func (f *tokenFactory) Parse(tokenString string) (jwt.MapClaims, error) {
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
......@@ -85,7 +97,7 @@ func (f *tokenFactory) ParseConfirmRegistrationToken(token string) (uint32, stri
t, ok := claims["type"]
if !ok || TokenType(t.(float64)) != ConfirmRegistrationTokenType {
return 0, "", errors.New("invalid token type")
return 0, "", ErrOddTokenType
}
userID, ok := claims["userID"]
......@@ -110,3 +122,41 @@ func (f *tokenFactory) ParseConfirmRegistrationToken(token string) (uint32, stri
return uint32(numUserID), strEmail, expired
}
func (f *tokenFactory) ParsePasswordResetToken(token string) (uint32, string, error) {
var expired error
claims, err := f.Parse(token)
if err != nil {
if err != ErrTokenExpired {
return 0, "", err
}
expired = err
}
t, ok := claims["type"]
if !ok || TokenType(t.(float64)) != PasswordResetTokenType {
return 0, "", ErrOddTokenType
}
userID, ok := claims["userID"]
if !ok {
return 0, "", errors.New("unable to read userID")
}
numUserID, ok := userID.(float64)
if !ok {
return 0, "", errors.New("failed to parse userID")
}
hash, ok := claims["hash"]
if !ok {
return 0, "", errors.New("failed to read hash")
}
strHash, ok := hash.(string)
if !ok {
return 0, "", errors.New("failed to parse hash")
}
return uint32(numUserID), strHash, expired
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment