mirror of
https://github.com/better-auth/better-auth.git
synced 2026-05-22 14:21:55 -05:00
feat(oauth-provider): public client prelogin endpoint (#8214)
This commit is contained in:
@@ -195,6 +195,33 @@ type getOAuthClientPublic = {
|
||||
```
|
||||
</APIMethod>
|
||||
|
||||
#### Get Public Client Prelogin
|
||||
|
||||
To obtain a public client prior to login, you must first enable the endpoint in your configuration:
|
||||
|
||||
```ts title="auth.ts"
|
||||
oauthProvider({
|
||||
allowPublicClientPrelogin: true,
|
||||
})
|
||||
```
|
||||
|
||||
Then, the following endpoint will obtain public client information.
|
||||
|
||||
<APIMethod path="/oauth2/public-client-prelogin" method="POST">
|
||||
```ts
|
||||
type getOAuthClientPublicPrelogin = {
|
||||
/**
|
||||
* The OAuth client's client_id
|
||||
*/
|
||||
client_id: string,
|
||||
/**
|
||||
* Valid oauth query parameters (Sent automatically when using the provided client)
|
||||
*/
|
||||
oauth_query: string
|
||||
}
|
||||
```
|
||||
</APIMethod>
|
||||
|
||||
#### List Clients
|
||||
|
||||
To obtain a list of clients owned by a specific user or organization, use the following endpoint:
|
||||
|
||||
17
packages/oauth-provider/src/middleware/index.ts
Normal file
17
packages/oauth-provider/src/middleware/index.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
import { APIError, createAuthMiddleware } from "better-auth/api";
|
||||
import type { OAuthOptions, Scope } from "../types";
|
||||
import { verifyOAuthQueryParams } from "../utils";
|
||||
|
||||
export const publicSessionMiddleware = (opts: OAuthOptions<Scope[]>) =>
|
||||
createAuthMiddleware(async (ctx) => {
|
||||
if (!opts.allowPublicClientPrelogin) {
|
||||
throw new APIError("BAD_REQUEST");
|
||||
}
|
||||
const query = ctx.body.oauth_query;
|
||||
const isValid = await verifyOAuthQueryParams(query, ctx.context.secret);
|
||||
if (!isValid) {
|
||||
throw new APIError("UNAUTHORIZED", {
|
||||
error: "invalid_signature",
|
||||
});
|
||||
}
|
||||
});
|
||||
@@ -9,7 +9,6 @@ import {
|
||||
sessionMiddleware,
|
||||
} from "better-auth/api";
|
||||
import { parseSetCookieHeader } from "better-auth/cookies";
|
||||
import { constantTimeEqual, makeSignature } from "better-auth/crypto";
|
||||
import { mergeSchema } from "better-auth/db";
|
||||
import type { BetterAuthPlugin } from "better-auth/types";
|
||||
import * as z from "zod";
|
||||
@@ -28,7 +27,11 @@ import { tokenEndpoint } from "./token";
|
||||
import type { OAuthOptions, Scope } from "./types";
|
||||
import { SafeUrlSchema } from "./types/zod";
|
||||
import { userInfoEndpoint } from "./userinfo";
|
||||
import { deleteFromPrompt, getJwtPlugin } from "./utils";
|
||||
import {
|
||||
deleteFromPrompt,
|
||||
getJwtPlugin,
|
||||
verifyOAuthQueryParams,
|
||||
} from "./utils";
|
||||
|
||||
declare module "@better-auth/core" {
|
||||
interface BetterAuthPluginRegistry<AuthOptions, Options> {
|
||||
@@ -209,27 +212,20 @@ export const oauthProvider = <O extends OAuthOptions<Scope[]>>(options: O) => {
|
||||
handler: createAuthMiddleware(async (ctx) => {
|
||||
// Verify query signature
|
||||
const query = ctx.body.oauth_query;
|
||||
let queryParams = new URLSearchParams(query);
|
||||
const sig = queryParams.get("sig");
|
||||
const exp = Number(queryParams.get("exp"));
|
||||
queryParams.delete("sig");
|
||||
queryParams = new URLSearchParams(queryParams);
|
||||
const verifySig = await makeSignature(
|
||||
queryParams.toString(),
|
||||
const isValid = await verifyOAuthQueryParams(
|
||||
query,
|
||||
ctx.context.secret,
|
||||
);
|
||||
if (
|
||||
!sig ||
|
||||
!constantTimeEqual(sig, verifySig) ||
|
||||
new Date(exp * 1000) < new Date()
|
||||
) {
|
||||
if (!isValid) {
|
||||
throw new APIError("BAD_REQUEST", {
|
||||
error: "invalid_signature",
|
||||
});
|
||||
}
|
||||
const queryParams = new URLSearchParams(query);
|
||||
queryParams.delete("sig");
|
||||
queryParams.delete("exp");
|
||||
await oAuthState.set({
|
||||
query: new URLSearchParams(queryParams).toString(),
|
||||
query: queryParams.toString(),
|
||||
});
|
||||
|
||||
// If path starts oauth2 authorize (ie /sign-in/social, /sign-in/oauth2), add to additional data body
|
||||
@@ -1302,6 +1298,8 @@ export const oauthProvider = <O extends OAuthOptions<Scope[]>>(options: O) => {
|
||||
createOAuthClient: oauthClientEndpoints.createOAuthClient(opts),
|
||||
getOAuthClient: oauthClientEndpoints.getOAuthClient(opts),
|
||||
getOAuthClientPublic: oauthClientEndpoints.getOAuthClientPublic(opts),
|
||||
getOAuthClientPublicPrelogin:
|
||||
oauthClientEndpoints.getOAuthClientPublicPrelogin(opts),
|
||||
getOAuthClients: oauthClientEndpoints.getOAuthClients(opts),
|
||||
adminUpdateOAuthClient: oauthClientEndpoints.adminUpdateOAuthClient(opts),
|
||||
updateOAuthClient: oauthClientEndpoints.updateOAuthClient(opts),
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { createAuthClient } from "better-auth/client";
|
||||
import { makeSignature } from "better-auth/crypto";
|
||||
import { jwt } from "better-auth/plugins/jwt";
|
||||
import { getTestInstance } from "better-auth/test";
|
||||
import { describe, expect, it } from "vitest";
|
||||
@@ -11,7 +12,7 @@ describe("oauthClient", async () => {
|
||||
const baseUrl = "http://localhost:3000";
|
||||
const rpBaseUrl = "http://localhost:5000";
|
||||
const redirectUri = `${rpBaseUrl}/api/auth/oauth2/callback/${providerId}`;
|
||||
const { signInWithTestUser, customFetchImpl } = await getTestInstance({
|
||||
const { auth, signInWithTestUser, customFetchImpl } = await getTestInstance({
|
||||
baseURL: baseUrl,
|
||||
plugins: [
|
||||
oauthProvider({
|
||||
@@ -21,6 +22,7 @@ describe("oauthClient", async () => {
|
||||
oauthAuthServerConfig: true,
|
||||
openidConfig: true,
|
||||
},
|
||||
allowPublicClientPrelogin: true,
|
||||
}),
|
||||
jwt(),
|
||||
],
|
||||
@@ -103,6 +105,27 @@ describe("oauthClient", async () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("should get public-only information about a client prelogin", async () => {
|
||||
// Creates mock valid search params
|
||||
const signedParams = new URLSearchParams({
|
||||
exp: `${Math.floor(Date.now() / 1000) + 60}`,
|
||||
});
|
||||
const sig = await makeSignature(
|
||||
signedParams.toString(),
|
||||
(auth.options as unknown as { secret: string }).secret,
|
||||
);
|
||||
signedParams.set("sig", sig);
|
||||
|
||||
const client = await authClient.oauth2.publicClientPrelogin({
|
||||
client_id: oauthUiClient.client_id,
|
||||
oauth_query: signedParams.toString(),
|
||||
});
|
||||
expect(client.data).toMatchObject({
|
||||
client_id: oauthUiClient.client_id,
|
||||
...testUiClientInput,
|
||||
});
|
||||
});
|
||||
|
||||
it("should get user's clients", async () => {
|
||||
const clients = await authClient.oauth2.getClients();
|
||||
expect(clients?.data?.length).toBe(3);
|
||||
|
||||
@@ -53,10 +53,11 @@ export async function getClientEndpoint(
|
||||
* This is commonly used to display information on login flow pages.
|
||||
*/
|
||||
export async function getClientPublicEndpoint(
|
||||
ctx: GenericEndpointContext & { query: { client_id: string } },
|
||||
ctx: GenericEndpointContext,
|
||||
opts: OAuthOptions<Scope[]>,
|
||||
clientId: string,
|
||||
) {
|
||||
const client = await getClient(ctx, opts, ctx.query.client_id);
|
||||
const client = await getClient(ctx, opts, clientId);
|
||||
if (!client) {
|
||||
throw new APIError("NOT_FOUND", {
|
||||
error_description: "client not found",
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { createAuthEndpoint, sessionMiddleware } from "better-auth/api";
|
||||
import * as z from "zod";
|
||||
import { publicSessionMiddleware } from "../middleware";
|
||||
import { createOAuthClientEndpoint } from "../register";
|
||||
import type { OAuthOptions, Scope } from "../types";
|
||||
import { SafeUrlSchema } from "../types/zod";
|
||||
@@ -451,12 +452,35 @@ export const getOAuthClientPublic = (opts: OAuthOptions<Scope[]>) =>
|
||||
}),
|
||||
metadata: {
|
||||
openapi: {
|
||||
description: "Gets publically available client fields",
|
||||
description: "Gets publicly available client fields",
|
||||
},
|
||||
},
|
||||
},
|
||||
async (ctx) => {
|
||||
return getClientPublicEndpoint(ctx, opts);
|
||||
const clientId = ctx.query.client_id;
|
||||
return getClientPublicEndpoint(ctx, opts, clientId);
|
||||
},
|
||||
);
|
||||
|
||||
export const getOAuthClientPublicPrelogin = (opts: OAuthOptions<Scope[]>) =>
|
||||
createAuthEndpoint(
|
||||
"/oauth2/public-client-prelogin",
|
||||
{
|
||||
method: "POST",
|
||||
use: [publicSessionMiddleware(opts)],
|
||||
body: z.object({
|
||||
client_id: z.string(),
|
||||
oauth_query: z.string().optional(),
|
||||
}),
|
||||
metadata: {
|
||||
openapi: {
|
||||
description: "Gets publicly available client fields (prior to login)",
|
||||
},
|
||||
},
|
||||
},
|
||||
async (ctx) => {
|
||||
const clientId = ctx.body.client_id;
|
||||
return getClientPublicEndpoint(ctx, opts, clientId);
|
||||
},
|
||||
);
|
||||
|
||||
|
||||
@@ -115,6 +115,11 @@ export interface OAuthOptions<
|
||||
scopeExpirations?: {
|
||||
[K in Scopes[number]]?: number | string | Date;
|
||||
};
|
||||
/**
|
||||
* Allows /oauth2/public-client-prelogin endpoint to be
|
||||
* requestable prior to login via a valid oauth_query.
|
||||
*/
|
||||
allowPublicClientPrelogin?: boolean;
|
||||
/**
|
||||
* Allow unauthenticated dynamic client registration.
|
||||
*
|
||||
|
||||
@@ -62,6 +62,22 @@ export const getJwtPlugin = (ctx: AuthContext) => {
|
||||
|
||||
const cachedTrustedClients = new TTLCache<string, SchemaClient<Scope[]>>();
|
||||
|
||||
export async function verifyOAuthQueryParams(
|
||||
oauth_query: string,
|
||||
secret: string,
|
||||
) {
|
||||
const queryParams = new URLSearchParams(oauth_query);
|
||||
const sig = queryParams.get("sig");
|
||||
const exp = Number(queryParams.get("exp"));
|
||||
queryParams.delete("sig");
|
||||
const verifySig = await makeSignature(queryParams.toString(), secret);
|
||||
return (
|
||||
!!sig &&
|
||||
constantTimeEqual(sig, verifySig) &&
|
||||
new Date(exp * 1000) >= new Date()
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a client by ID, checking trusted clients first, then database
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user