feat: new user callback url for sso and generic oauth (#1176)

This commit is contained in:
kzlar
2025-01-14 04:58:30 -05:00
committed by GitHub
parent 68c4d151b8
commit 1a97e31db2
3 changed files with 75 additions and 21 deletions

View File

@@ -274,6 +274,12 @@ export const genericOAuth = (options: GenericOAuthOptions) => {
description: "The URL to redirect to if an error occurs",
})
.optional(),
newUserCallbackURL: z
.string({
description:
"The URL to redirect to after login if the user is new",
})
.optional(),
disableRedirect: z
.boolean({
description: "Disable redirect",
@@ -460,7 +466,8 @@ export const genericOAuth = (options: GenericOAuthOptions) => {
let tokens: OAuth2Tokens | undefined = undefined;
const parsedState = await parseState(ctx);
const { callbackURL, codeVerifier, errorURL } = parsedState;
const { callbackURL, codeVerifier, errorURL, newUserURL } =
parsedState;
const code = ctx.query.code;
let finalTokenUrl = provider.tokenUrl;
@@ -557,10 +564,14 @@ export const genericOAuth = (options: GenericOAuthOptions) => {
});
let toRedirectTo: string;
try {
const url = new URL(callbackURL);
const url = result.isRegister
? newUserURL || callbackURL
: callbackURL;
toRedirectTo = url.toString();
} catch {
toRedirectTo = callbackURL;
toRedirectTo = result.isRegister
? newUserURL || callbackURL
: callbackURL;
}
throw ctx.redirect(toRedirectTo);
},

View File

@@ -15,20 +15,7 @@ describe("oauth2", async () => {
const clientId = "test-client-id";
const clientSecret = "test-client-secret";
beforeAll(async () => {
await server.issuer.keys.generate("RS256");
server.issuer.on;
// Start the server
await server.start(8080, "localhost");
console.log("Issuer URL:", server.issuer.url); // -> http://localhost:8080
});
afterAll(async () => {
await server.stop();
});
const { customFetchImpl } = await getTestInstance({
const { customFetchImpl, auth } = await getTestInstance({
plugins: [
genericOAuth({
config: [
@@ -54,6 +41,24 @@ describe("oauth2", async () => {
},
});
beforeAll(async () => {
const context = await auth.$context;
await context.internalAdapter.createUser({
email: "oauth2@test.com",
name: "OAuth2 Test",
});
await server.issuer.keys.generate("RS256");
server.issuer.on;
// Start the server
await server.start(8080, "localhost");
console.log("Issuer URL:", server.issuer.url); // -> http://localhost:8080
});
afterAll(async () => {
await server.stop();
});
server.service.on("beforeUserinfo", (userInfoResponse, req) => {
userInfoResponse.body = {
email: "oauth2@test.com",
@@ -99,6 +104,7 @@ describe("oauth2", async () => {
const signInRes = await authClient.signIn.oauth2({
providerId: "test",
callbackURL: "http://localhost:3000/dashboard",
newUserCallbackURL: "http://localhost:3000/new_user",
});
expect(signInRes.data).toMatchObject({
url: expect.stringContaining("http://localhost:8080/authorize"),
@@ -111,11 +117,41 @@ describe("oauth2", async () => {
expect(callbackURL).toBe("http://localhost:3000/dashboard");
});
it("should redirect to the provider and handle the response for a new user", async () => {
server.service.once("beforeUserinfo", (userInfoResponse) => {
userInfoResponse.body = {
email: "oauth2-2@test.com",
name: "OAuth2 Test 2",
sub: "oauth2-2",
picture: "https://test.com/picture.png",
email_verified: true,
};
userInfoResponse.statusCode = 200;
});
let headers = new Headers();
const signInRes = await authClient.signIn.oauth2({
providerId: "test",
callbackURL: "http://localhost:3000/dashboard",
newUserCallbackURL: "http://localhost:3000/new_user",
});
expect(signInRes.data).toMatchObject({
url: expect.stringContaining("http://localhost:8080/authorize"),
redirect: true,
});
const callbackURL = await simulateOAuthFlow(
signInRes.data?.url || "",
headers,
);
expect(callbackURL).toBe("http://localhost:3000/new_user");
});
it("should redirect to the provider and handle the response after linked", async () => {
let headers = new Headers();
const res = await authClient.signIn.oauth2({
providerId: "test",
callbackURL: "http://localhost:3000/dashboard",
newUserCallbackURL: "http://localhost:3000/new_user",
});
const callbackURL = await simulateOAuthFlow(res.data?.url || "", headers);
expect(callbackURL).toBe("http://localhost:3000/dashboard");
@@ -125,6 +161,7 @@ describe("oauth2", async () => {
const res = await authClient.signIn.oauth2({
providerId: "invalid-provider",
callbackURL: "http://localhost:3000/dashboard",
newUserCallbackURL: "http://localhost:3000/new_user",
});
expect(res.error?.status).toBe(400);
});
@@ -146,6 +183,7 @@ describe("oauth2", async () => {
{
providerId: "test",
callbackURL: "http://localhost:3000/dashboard",
newUserCallbackURL: "http://localhost:3000/new_user",
},
{
onSuccess(context) {
@@ -169,7 +207,7 @@ describe("oauth2", async () => {
});
it("should work with custom redirect uri", async () => {
const { customFetchImpl } = await getTestInstance({
const { customFetchImpl, auth } = await getTestInstance({
plugins: [
genericOAuth({
config: [
@@ -198,6 +236,7 @@ describe("oauth2", async () => {
const res = await authClient.signIn.oauth2({
providerId: "test2",
callbackURL: "http://localhost:3000/dashboard",
newUserCallbackURL: "http://localhost:3000/new_user",
});
expect(res.data?.url).toContain("http://localhost:8080/authorize");
const headers = new Headers();
@@ -206,6 +245,6 @@ describe("oauth2", async () => {
headers,
customFetchImpl,
);
expect(callbackURL).toBe("http://localhost:3000/dashboard");
expect(callbackURL).toBe("http://localhost:3000/new_user");
});
});

View File

@@ -679,10 +679,14 @@ export const sso = (options?: SSOOptions) => {
});
let toRedirectTo: string;
try {
const url = new URL(callbackURL);
const url = linked.isRegister
? newUserURL || callbackURL
: callbackURL;
toRedirectTo = url.toString();
} catch {
toRedirectTo = callbackURL;
toRedirectTo = linked.isRegister
? newUserURL || callbackURL
: callbackURL;
}
throw ctx.redirect(toRedirectTo);
},