mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-01-09 14:28:05 -05:00
fix: add validation for callback URLs (#929)
This commit is contained in:
@@ -31,8 +31,8 @@ type OidcClientWithAllowedGroupsCountDto struct {
|
||||
|
||||
type OidcClientUpdateDto struct {
|
||||
Name string `json:"name" binding:"required,max=50" unorm:"nfc"`
|
||||
CallbackURLs []string `json:"callbackURLs"`
|
||||
LogoutCallbackURLs []string `json:"logoutCallbackURLs"`
|
||||
CallbackURLs []string `json:"callbackURLs" binding:"omitempty,dive,callback_url"`
|
||||
LogoutCallbackURLs []string `json:"logoutCallbackURLs" binding:"omitempty,dive,callback_url"`
|
||||
IsPublic bool `json:"isPublic"`
|
||||
PkceEnabled bool `json:"pkceEnabled"`
|
||||
RequiresReauthentication bool `json:"requiresReauthentication"`
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||
@@ -23,32 +25,34 @@ func init() {
|
||||
// Maximum allowed value for TTLs
|
||||
const maxTTL = 31 * 24 * time.Hour
|
||||
|
||||
// Errors here are development-time ones
|
||||
err := v.RegisterValidation("username", func(fl validator.FieldLevel) bool {
|
||||
if err := v.RegisterValidation("username", func(fl validator.FieldLevel) bool {
|
||||
return ValidateUsername(fl.Field().String())
|
||||
})
|
||||
if err != nil {
|
||||
}); err != nil {
|
||||
panic("Failed to register custom validation for username: " + err.Error())
|
||||
}
|
||||
|
||||
err = v.RegisterValidation("client_id", func(fl validator.FieldLevel) bool {
|
||||
if err := v.RegisterValidation("client_id", func(fl validator.FieldLevel) bool {
|
||||
return ValidateClientID(fl.Field().String())
|
||||
})
|
||||
if err != nil {
|
||||
}); err != nil {
|
||||
panic("Failed to register custom validation for client_id: " + err.Error())
|
||||
}
|
||||
|
||||
err = v.RegisterValidation("ttl", func(fl validator.FieldLevel) bool {
|
||||
if err := v.RegisterValidation("ttl", func(fl validator.FieldLevel) bool {
|
||||
ttl, ok := fl.Field().Interface().(utils.JSONDuration)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
// Allow zero, which means the field wasn't set
|
||||
return ttl.Duration == 0 || ttl.Duration > time.Second && ttl.Duration <= maxTTL
|
||||
})
|
||||
if err != nil {
|
||||
return ttl.Duration == 0 || (ttl.Duration > time.Second && ttl.Duration <= maxTTL)
|
||||
}); err != nil {
|
||||
panic("Failed to register custom validation for ttl: " + err.Error())
|
||||
}
|
||||
|
||||
if err := v.RegisterValidation("callback_url", func(fl validator.FieldLevel) bool {
|
||||
return ValidateCallbackURL(fl.Field().String())
|
||||
}); err != nil {
|
||||
panic("Failed to register custom validation for callback_url: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateUsername validates username inputs
|
||||
@@ -60,3 +64,24 @@ func ValidateUsername(username string) bool {
|
||||
func ValidateClientID(clientID string) bool {
|
||||
return validateClientIDRegex.MatchString(clientID)
|
||||
}
|
||||
|
||||
// ValidateCallbackURL validates callback URLs with support for wildcards
|
||||
func ValidateCallbackURL(raw string) bool {
|
||||
if raw == "*" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Replace all '*' with 'x' to check if the rest is still a valid URI
|
||||
test := strings.ReplaceAll(raw, "*", "x")
|
||||
|
||||
u, err := url.Parse(test)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if !u.IsAbs() {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { m } from '$lib/paraglide/messages';
|
||||
import z from 'zod/v4';
|
||||
|
||||
export const emptyToUndefined = <T>(validation: z.ZodType<T>) =>
|
||||
@@ -7,3 +8,21 @@ export const optionalUrl = z
|
||||
.url()
|
||||
.optional()
|
||||
.or(z.literal('').transform(() => undefined));
|
||||
|
||||
export const callbackUrlSchema = z
|
||||
.string()
|
||||
.nonempty()
|
||||
.refine(
|
||||
(val) => {
|
||||
if (val === '*') return true;
|
||||
try {
|
||||
new URL(val.replace(/\*/g, 'x'));
|
||||
return true;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
},
|
||||
{
|
||||
message: m.invalid_redirect_url()
|
||||
}
|
||||
);
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
import { preventDefault } from '$lib/utils/event-util';
|
||||
import { createForm } from '$lib/utils/form-util';
|
||||
import { cn } from '$lib/utils/style';
|
||||
import { emptyToUndefined, optionalUrl } from '$lib/utils/zod-util';
|
||||
import { callbackUrlSchema, emptyToUndefined, optionalUrl } from '$lib/utils/zod-util';
|
||||
import { LucideChevronDown } from '@lucide/svelte';
|
||||
import { slide } from 'svelte/transition';
|
||||
import { z } from 'zod/v4';
|
||||
@@ -65,8 +65,8 @@
|
||||
.optional()
|
||||
),
|
||||
name: z.string().min(2).max(50),
|
||||
callbackURLs: z.array(z.string().nonempty()).default([]),
|
||||
logoutCallbackURLs: z.array(z.string().nonempty()),
|
||||
callbackURLs: z.array(callbackUrlSchema).default([]),
|
||||
logoutCallbackURLs: z.array(callbackUrlSchema).default([]),
|
||||
isPublic: z.boolean(),
|
||||
pkceEnabled: z.boolean(),
|
||||
requiresReauthentication: z.boolean(),
|
||||
|
||||
Reference in New Issue
Block a user