mirror of
https://github.com/ollama/ollama.git
synced 2026-03-08 23:04:13 -05:00
Reapply "don't require pulling stubs for cloud models"
This reverts commit 97d2f05a6d.
This commit is contained in:
460
server/cloud_proxy.go
Normal file
460
server/cloud_proxy.go
Normal file
@@ -0,0 +1,460 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
internalcloud "github.com/ollama/ollama/internal/cloud"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultCloudProxyBaseURL = "https://ollama.com:443"
|
||||
defaultCloudProxySigningHost = "ollama.com"
|
||||
cloudProxyBaseURLEnv = "OLLAMA_CLOUD_BASE_URL"
|
||||
legacyCloudAnthropicKey = "legacy_cloud_anthropic_web_search"
|
||||
)
|
||||
|
||||
var (
|
||||
cloudProxyBaseURL = defaultCloudProxyBaseURL
|
||||
cloudProxySigningHost = defaultCloudProxySigningHost
|
||||
cloudProxySignRequest = signCloudProxyRequest
|
||||
cloudProxySigninURL = signinURL
|
||||
)
|
||||
|
||||
var hopByHopHeaders = map[string]struct{}{
|
||||
"connection": {},
|
||||
"content-length": {},
|
||||
"proxy-connection": {},
|
||||
"keep-alive": {},
|
||||
"proxy-authenticate": {},
|
||||
"proxy-authorization": {},
|
||||
"te": {},
|
||||
"trailer": {},
|
||||
"transfer-encoding": {},
|
||||
"upgrade": {},
|
||||
}
|
||||
|
||||
func init() {
|
||||
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL(envconfig.Var(cloudProxyBaseURLEnv), mode)
|
||||
if err != nil {
|
||||
slog.Warn("ignoring cloud base URL override", "env", cloudProxyBaseURLEnv, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
cloudProxyBaseURL = baseURL
|
||||
cloudProxySigningHost = signingHost
|
||||
|
||||
if overridden {
|
||||
slog.Info("cloud base URL override enabled", "env", cloudProxyBaseURLEnv, "url", cloudProxyBaseURL, "mode", mode)
|
||||
}
|
||||
}
|
||||
|
||||
func cloudPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if c.Request.Method != http.MethodPost {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(drifkin): Avoid full-body buffering here for model detection.
|
||||
// A future optimization can parse just enough JSON to read "model" (and
|
||||
// optionally short-circuit cloud-disabled explicit-cloud requests) while
|
||||
// preserving raw passthrough semantics.
|
||||
body, err := readRequestBody(c.Request)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
model, ok := extractModelField(body)
|
||||
if !ok {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
modelRef, err := parseAndValidateModelRef(model)
|
||||
if err != nil || modelRef.Source != modelSourceCloud {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
normalizedBody, err := replaceJSONModelField(body, modelRef.Base)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// TEMP(drifkin): keep Anthropic web search requests on the local middleware
|
||||
// path so WebSearchAnthropicWriter can orchestrate follow-up calls.
|
||||
if c.Request.URL.Path == "/v1/messages" {
|
||||
if hasAnthropicWebSearchTool(body) {
|
||||
c.Set(legacyCloudAnthropicKey, true)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
proxyCloudRequest(c, normalizedBody, disabledOperation)
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
func cloudModelPathPassthroughMiddleware(disabledOperation string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
modelName := strings.TrimSpace(c.Param("model"))
|
||||
if modelName == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
modelRef, err := parseAndValidateModelRef(modelName)
|
||||
if err != nil || modelRef.Source != modelSourceCloud {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
proxyPath := "/v1/models/" + modelRef.Base
|
||||
proxyCloudRequestWithPath(c, nil, proxyPath, disabledOperation)
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
func proxyCloudJSONRequest(c *gin.Context, payload any, disabledOperation string) {
|
||||
// TEMP(drifkin): we currently split out this `WithPath` method because we are
|
||||
// mapping `/v1/messages` + web_search to `/api/chat` temporarily. Once we
|
||||
// stop doing this, we can inline this method.
|
||||
proxyCloudJSONRequestWithPath(c, payload, c.Request.URL.Path, disabledOperation)
|
||||
}
|
||||
|
||||
func proxyCloudJSONRequestWithPath(c *gin.Context, payload any, path string, disabledOperation string) {
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
proxyCloudRequestWithPath(c, body, path, disabledOperation)
|
||||
}
|
||||
|
||||
func proxyCloudRequest(c *gin.Context, body []byte, disabledOperation string) {
|
||||
proxyCloudRequestWithPath(c, body, c.Request.URL.Path, disabledOperation)
|
||||
}
|
||||
|
||||
func proxyCloudRequestWithPath(c *gin.Context, body []byte, path string, disabledOperation string) {
|
||||
if disabled, _ := internalcloud.Status(); disabled {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(disabledOperation)})
|
||||
return
|
||||
}
|
||||
|
||||
baseURL, err := url.Parse(cloudProxyBaseURL)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
targetURL := baseURL.ResolveReference(&url.URL{
|
||||
Path: path,
|
||||
RawQuery: c.Request.URL.RawQuery,
|
||||
})
|
||||
|
||||
outReq, err := http.NewRequestWithContext(c.Request.Context(), c.Request.Method, targetURL.String(), bytes.NewReader(body))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
copyProxyRequestHeaders(outReq.Header, c.Request.Header)
|
||||
if outReq.Header.Get("Content-Type") == "" && len(body) > 0 {
|
||||
outReq.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
if err := cloudProxySignRequest(outReq.Context(), outReq); err != nil {
|
||||
slog.Warn("cloud proxy signing failed", "error", err)
|
||||
writeCloudUnauthorized(c)
|
||||
return
|
||||
}
|
||||
|
||||
// TODO(drifkin): Add phase-specific proxy timeouts.
|
||||
// Connect/TLS/TTFB should have bounded timeouts, but once streaming starts
|
||||
// we should not enforce a short total timeout for long-lived responses.
|
||||
resp, err := http.DefaultClient.Do(outReq)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
copyProxyResponseHeaders(c.Writer.Header(), resp.Header)
|
||||
c.Status(resp.StatusCode)
|
||||
|
||||
if err := copyProxyResponseBody(c.Writer, resp.Body); err != nil {
|
||||
c.Error(err) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
|
||||
func replaceJSONModelField(body []byte, model string) ([]byte, error) {
|
||||
if len(body) == 0 {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
var payload map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelJSON, err := json.Marshal(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload["model"] = modelJSON
|
||||
|
||||
return json.Marshal(payload)
|
||||
}
|
||||
|
||||
func readRequestBody(r *http.Request) ([]byte, error) {
|
||||
if r.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func extractModelField(body []byte) (string, bool) {
|
||||
if len(body) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var payload map[string]json.RawMessage
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
raw, ok := payload["model"]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
var model string
|
||||
if err := json.Unmarshal(raw, &model); err != nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
model = strings.TrimSpace(model)
|
||||
return model, model != ""
|
||||
}
|
||||
|
||||
func hasAnthropicWebSearchTool(body []byte) bool {
|
||||
if len(body) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
Tools []struct {
|
||||
Type string `json:"type"`
|
||||
} `json:"tools"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, tool := range payload.Tools {
|
||||
if strings.HasPrefix(strings.TrimSpace(tool.Type), "web_search") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func writeCloudUnauthorized(c *gin.Context) {
|
||||
signinURL, err := cloudProxySigninURL()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized", "signin_url": signinURL})
|
||||
}
|
||||
|
||||
func signCloudProxyRequest(ctx context.Context, req *http.Request) error {
|
||||
if !strings.EqualFold(req.URL.Hostname(), cloudProxySigningHost) {
|
||||
return nil
|
||||
}
|
||||
|
||||
ts := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
challenge := buildCloudSignatureChallenge(req, ts)
|
||||
signature, err := auth.Sign(ctx, []byte(challenge))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", signature)
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildCloudSignatureChallenge(req *http.Request, ts string) string {
|
||||
query := req.URL.Query()
|
||||
query.Set("ts", ts)
|
||||
req.URL.RawQuery = query.Encode()
|
||||
|
||||
return fmt.Sprintf("%s,%s", req.Method, req.URL.RequestURI())
|
||||
}
|
||||
|
||||
func resolveCloudProxyBaseURL(rawOverride string, runMode string) (baseURL string, signingHost string, overridden bool, err error) {
|
||||
baseURL = defaultCloudProxyBaseURL
|
||||
signingHost = defaultCloudProxySigningHost
|
||||
|
||||
rawOverride = strings.TrimSpace(rawOverride)
|
||||
if rawOverride == "" {
|
||||
return baseURL, signingHost, false, nil
|
||||
}
|
||||
|
||||
u, err := url.Parse(rawOverride)
|
||||
if err != nil {
|
||||
return "", "", false, fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
if u.Scheme == "" || u.Host == "" {
|
||||
return "", "", false, fmt.Errorf("invalid URL: scheme and host are required")
|
||||
}
|
||||
if u.User != nil {
|
||||
return "", "", false, fmt.Errorf("invalid URL: userinfo is not allowed")
|
||||
}
|
||||
if u.Path != "" && u.Path != "/" {
|
||||
return "", "", false, fmt.Errorf("invalid URL: path is not allowed")
|
||||
}
|
||||
if u.RawQuery != "" || u.Fragment != "" {
|
||||
return "", "", false, fmt.Errorf("invalid URL: query and fragment are not allowed")
|
||||
}
|
||||
|
||||
host := u.Hostname()
|
||||
if host == "" {
|
||||
return "", "", false, fmt.Errorf("invalid URL: host is required")
|
||||
}
|
||||
|
||||
loopback := isLoopbackHost(host)
|
||||
if runMode == gin.ReleaseMode && !loopback {
|
||||
return "", "", false, fmt.Errorf("non-loopback cloud override is not allowed in release mode")
|
||||
}
|
||||
if !loopback && !strings.EqualFold(u.Scheme, "https") {
|
||||
return "", "", false, fmt.Errorf("non-loopback cloud override must use https")
|
||||
}
|
||||
|
||||
u.Path = ""
|
||||
u.RawPath = ""
|
||||
u.RawQuery = ""
|
||||
u.Fragment = ""
|
||||
|
||||
return u.String(), strings.ToLower(host), true, nil
|
||||
}
|
||||
|
||||
func isLoopbackHost(host string) bool {
|
||||
if strings.EqualFold(host, "localhost") {
|
||||
return true
|
||||
}
|
||||
|
||||
ip := net.ParseIP(host)
|
||||
return ip != nil && ip.IsLoopback()
|
||||
}
|
||||
|
||||
func copyProxyRequestHeaders(dst, src http.Header) {
|
||||
connectionTokens := connectionHeaderTokens(src)
|
||||
for key, values := range src {
|
||||
if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) {
|
||||
continue
|
||||
}
|
||||
|
||||
dst.Del(key)
|
||||
for _, value := range values {
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func copyProxyResponseHeaders(dst, src http.Header) {
|
||||
connectionTokens := connectionHeaderTokens(src)
|
||||
for key, values := range src {
|
||||
if isHopByHopHeader(key) || isConnectionTokenHeader(key, connectionTokens) {
|
||||
continue
|
||||
}
|
||||
|
||||
dst.Del(key)
|
||||
for _, value := range values {
|
||||
dst.Add(key, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func copyProxyResponseBody(dst http.ResponseWriter, src io.Reader) error {
|
||||
flusher, canFlush := dst.(http.Flusher)
|
||||
buf := make([]byte, 32*1024)
|
||||
|
||||
for {
|
||||
n, err := src.Read(buf)
|
||||
if n > 0 {
|
||||
if _, writeErr := dst.Write(buf[:n]); writeErr != nil {
|
||||
return writeErr
|
||||
}
|
||||
if canFlush {
|
||||
// TODO(drifkin): Consider conditional flushing so non-streaming
|
||||
// responses don't flush every write and can optimize throughput.
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func isHopByHopHeader(name string) bool {
|
||||
_, ok := hopByHopHeaders[strings.ToLower(name)]
|
||||
return ok
|
||||
}
|
||||
|
||||
func connectionHeaderTokens(header http.Header) map[string]struct{} {
|
||||
tokens := map[string]struct{}{}
|
||||
for _, raw := range header.Values("Connection") {
|
||||
for _, token := range strings.Split(raw, ",") {
|
||||
token = strings.TrimSpace(strings.ToLower(token))
|
||||
if token == "" {
|
||||
continue
|
||||
}
|
||||
tokens[token] = struct{}{}
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
func isConnectionTokenHeader(name string, tokens map[string]struct{}) bool {
|
||||
if len(tokens) == 0 {
|
||||
return false
|
||||
}
|
||||
_, ok := tokens[strings.ToLower(name)]
|
||||
return ok
|
||||
}
|
||||
154
server/cloud_proxy_test.go
Normal file
154
server/cloud_proxy_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestCopyProxyRequestHeaders_StripsConnectionTokenHeaders(t *testing.T) {
|
||||
src := http.Header{}
|
||||
src.Add("Connection", "keep-alive, X-Trace-Hop, x-alt-hop")
|
||||
src.Add("X-Trace-Hop", "drop-me")
|
||||
src.Add("X-Alt-Hop", "drop-me-too")
|
||||
src.Add("Keep-Alive", "timeout=5")
|
||||
src.Add("X-End-To-End", "keep-me")
|
||||
|
||||
dst := http.Header{}
|
||||
copyProxyRequestHeaders(dst, src)
|
||||
|
||||
if got := dst.Get("Connection"); got != "" {
|
||||
t.Fatalf("expected Connection to be stripped, got %q", got)
|
||||
}
|
||||
if got := dst.Get("Keep-Alive"); got != "" {
|
||||
t.Fatalf("expected Keep-Alive to be stripped, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-Trace-Hop"); got != "" {
|
||||
t.Fatalf("expected X-Trace-Hop to be stripped via Connection token, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-Alt-Hop"); got != "" {
|
||||
t.Fatalf("expected X-Alt-Hop to be stripped via Connection token, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-End-To-End"); got != "keep-me" {
|
||||
t.Fatalf("expected X-End-To-End to be forwarded, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyProxyResponseHeaders_StripsConnectionTokenHeaders(t *testing.T) {
|
||||
src := http.Header{}
|
||||
src.Add("Connection", "X-Upstream-Hop")
|
||||
src.Add("X-Upstream-Hop", "drop-me")
|
||||
src.Add("Content-Type", "application/json")
|
||||
src.Add("X-Server-Trace", "keep-me")
|
||||
|
||||
dst := http.Header{}
|
||||
copyProxyResponseHeaders(dst, src)
|
||||
|
||||
if got := dst.Get("Connection"); got != "" {
|
||||
t.Fatalf("expected Connection to be stripped, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-Upstream-Hop"); got != "" {
|
||||
t.Fatalf("expected X-Upstream-Hop to be stripped via Connection token, got %q", got)
|
||||
}
|
||||
if got := dst.Get("Content-Type"); got != "application/json" {
|
||||
t.Fatalf("expected Content-Type to be forwarded, got %q", got)
|
||||
}
|
||||
if got := dst.Get("X-Server-Trace"); got != "keep-me" {
|
||||
t.Fatalf("expected X-Server-Trace to be forwarded, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_Default(t *testing.T) {
|
||||
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("", gin.ReleaseMode)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if overridden {
|
||||
t.Fatal("expected override=false for empty input")
|
||||
}
|
||||
if baseURL != defaultCloudProxyBaseURL {
|
||||
t.Fatalf("expected default base URL %q, got %q", defaultCloudProxyBaseURL, baseURL)
|
||||
}
|
||||
if signingHost != defaultCloudProxySigningHost {
|
||||
t.Fatalf("expected default signing host %q, got %q", defaultCloudProxySigningHost, signingHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_ReleaseAllowsLoopback(t *testing.T) {
|
||||
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("http://localhost:8080", gin.ReleaseMode)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !overridden {
|
||||
t.Fatal("expected override=true")
|
||||
}
|
||||
if baseURL != "http://localhost:8080" {
|
||||
t.Fatalf("unexpected base URL: %q", baseURL)
|
||||
}
|
||||
if signingHost != "localhost" {
|
||||
t.Fatalf("unexpected signing host: %q", signingHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_ReleaseRejectsNonLoopback(t *testing.T) {
|
||||
_, _, _, err := resolveCloudProxyBaseURL("https://example.com", gin.ReleaseMode)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-loopback override in release mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_DevAllowsNonLoopbackHTTPS(t *testing.T) {
|
||||
baseURL, signingHost, overridden, err := resolveCloudProxyBaseURL("https://example.com:8443", gin.DebugMode)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !overridden {
|
||||
t.Fatal("expected override=true")
|
||||
}
|
||||
if baseURL != "https://example.com:8443" {
|
||||
t.Fatalf("unexpected base URL: %q", baseURL)
|
||||
}
|
||||
if signingHost != "example.com" {
|
||||
t.Fatalf("unexpected signing host: %q", signingHost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveCloudProxyBaseURL_DevRejectsNonLoopbackHTTP(t *testing.T) {
|
||||
_, _, _, err := resolveCloudProxyBaseURL("http://example.com", gin.DebugMode)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-loopback http override in dev mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCloudSignatureChallengeIncludesExistingQuery(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&foo=bar", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
got := buildCloudSignatureChallenge(req, "123")
|
||||
want := "POST,/v1/messages?beta=true&foo=bar&ts=123"
|
||||
if got != want {
|
||||
t.Fatalf("challenge mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if req.URL.RawQuery != "beta=true&foo=bar&ts=123" {
|
||||
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCloudSignatureChallengeOverwritesExistingTimestamp(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodPost, "https://ollama.com/v1/messages?beta=true&ts=999", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create request: %v", err)
|
||||
}
|
||||
|
||||
got := buildCloudSignatureChallenge(req, "123")
|
||||
want := "POST,/v1/messages?beta=true&ts=123"
|
||||
if got != want {
|
||||
t.Fatalf("challenge mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if req.URL.RawQuery != "beta=true&ts=123" {
|
||||
t.Fatalf("unexpected signed query: %q", req.URL.RawQuery)
|
||||
}
|
||||
}
|
||||
@@ -110,19 +110,26 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
|
||||
if r.From != "" {
|
||||
slog.Debug("create model from model name", "from", r.From)
|
||||
fromName := model.ParseName(r.From)
|
||||
if !fromName.IsValid() {
|
||||
fromRef, err := parseAndValidateModelRef(r.From)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": errtypes.InvalidModelNameErrMsg, "status": http.StatusBadRequest}
|
||||
return
|
||||
}
|
||||
if r.RemoteHost != "" {
|
||||
ru, err := remoteURL(r.RemoteHost)
|
||||
|
||||
fromName := fromRef.Name
|
||||
remoteHost := r.RemoteHost
|
||||
if fromRef.Source == modelSourceCloud && remoteHost == "" {
|
||||
remoteHost = cloudProxyBaseURL
|
||||
}
|
||||
|
||||
if remoteHost != "" {
|
||||
ru, err := remoteURL(remoteHost)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": "bad remote", "status": http.StatusBadRequest}
|
||||
return
|
||||
}
|
||||
|
||||
config.RemoteModel = r.From
|
||||
config.RemoteModel = fromRef.Base
|
||||
config.RemoteHost = ru
|
||||
remote = true
|
||||
} else {
|
||||
|
||||
81
server/model_resolver.go
Normal file
81
server/model_resolver.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/internal/modelref"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
type modelSource = modelref.ModelSource
|
||||
|
||||
const (
|
||||
modelSourceUnspecified modelSource = modelref.ModelSourceUnspecified
|
||||
modelSourceLocal modelSource = modelref.ModelSourceLocal
|
||||
modelSourceCloud modelSource = modelref.ModelSourceCloud
|
||||
)
|
||||
|
||||
var (
|
||||
errConflictingModelSource = modelref.ErrConflictingSourceSuffix
|
||||
errModelRequired = modelref.ErrModelRequired
|
||||
)
|
||||
|
||||
type parsedModelRef struct {
|
||||
// Original is the caller-provided model string before source parsing.
|
||||
// Example: "gpt-oss:20b:cloud".
|
||||
Original string
|
||||
// Base is the model string after source suffix normalization.
|
||||
// Example: "gpt-oss:20b:cloud" -> "gpt-oss:20b".
|
||||
Base string
|
||||
// Name is Base parsed as a fully-qualified model.Name with defaults applied.
|
||||
// Example: "registry.ollama.ai/library/gpt-oss:20b".
|
||||
Name model.Name
|
||||
// Source captures explicit source intent from the original input.
|
||||
// Example: "gpt-oss:20b:cloud" -> modelSourceCloud.
|
||||
Source modelSource
|
||||
}
|
||||
|
||||
func parseAndValidateModelRef(raw string) (parsedModelRef, error) {
|
||||
var zero parsedModelRef
|
||||
|
||||
parsed, err := modelref.ParseRef(raw)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
name := model.ParseName(parsed.Base)
|
||||
if !name.IsValid() {
|
||||
return zero, model.Unqualified(name)
|
||||
}
|
||||
|
||||
return parsedModelRef{
|
||||
Original: parsed.Original,
|
||||
Base: parsed.Base,
|
||||
Name: name,
|
||||
Source: parsed.Source,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseNormalizePullModelRef(raw string) (parsedModelRef, error) {
|
||||
var zero parsedModelRef
|
||||
|
||||
parsedRef, err := modelref.ParseRef(raw)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
normalizedName, _, err := modelref.NormalizePullName(raw)
|
||||
if err != nil {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
name := model.ParseName(normalizedName)
|
||||
if !name.IsValid() {
|
||||
return zero, model.Unqualified(name)
|
||||
}
|
||||
|
||||
return parsedModelRef{
|
||||
Original: parsedRef.Original,
|
||||
Base: normalizedName,
|
||||
Name: name,
|
||||
Source: parsedRef.Source,
|
||||
}, nil
|
||||
}
|
||||
170
server/model_resolver_test.go
Normal file
170
server/model_resolver_test.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseModelSelector(t *testing.T) {
|
||||
t.Run("cloud suffix", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("gpt-oss:20b:cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceCloud {
|
||||
t.Fatalf("expected source cloud, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "gpt-oss:20b" {
|
||||
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
|
||||
}
|
||||
|
||||
if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b" {
|
||||
t.Fatalf("unexpected resolved name: %q", got.Name.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("legacy cloud suffix", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("gpt-oss:20b-cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceCloud {
|
||||
t.Fatalf("expected source cloud, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "gpt-oss:20b" {
|
||||
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("bare dash cloud name is not explicit cloud", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("my-cloud-model")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceUnspecified {
|
||||
t.Fatalf("expected source unspecified, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "my-cloud-model" {
|
||||
t.Fatalf("expected base my-cloud-model, got %q", got.Base)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("local suffix", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("qwen3:8b:local")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceLocal {
|
||||
t.Fatalf("expected source local, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "qwen3:8b" {
|
||||
t.Fatalf("expected base qwen3:8b, got %q", got.Base)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("conflicting source suffixes fail", func(t *testing.T) {
|
||||
_, err := parseAndValidateModelRef("foo:cloud:local")
|
||||
if !errors.Is(err, errConflictingModelSource) {
|
||||
t.Fatalf("expected errConflictingModelSource, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unspecified source", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("llama3")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceUnspecified {
|
||||
t.Fatalf("expected source unspecified, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Name.Tag != "latest" {
|
||||
t.Fatalf("expected default latest tag, got %q", got.Name.Tag)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("unknown suffix is treated as tag", func(t *testing.T) {
|
||||
got, err := parseAndValidateModelRef("gpt-oss:clod")
|
||||
if err != nil {
|
||||
t.Fatalf("parseModelSelector returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceUnspecified {
|
||||
t.Fatalf("expected source unspecified, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Name.Tag != "clod" {
|
||||
t.Fatalf("expected tag clod, got %q", got.Name.Tag)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("empty model fails", func(t *testing.T) {
|
||||
_, err := parseAndValidateModelRef("")
|
||||
if !errors.Is(err, errModelRequired) {
|
||||
t.Fatalf("expected errModelRequired, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("invalid model fails", func(t *testing.T) {
|
||||
_, err := parseAndValidateModelRef("::cloud")
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid model")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "unqualified") {
|
||||
t.Fatalf("expected unqualified model error, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestParsePullModelRef(t *testing.T) {
|
||||
t.Run("explicit local is normalized", func(t *testing.T) {
|
||||
got, err := parseNormalizePullModelRef("gpt-oss:20b:local")
|
||||
if err != nil {
|
||||
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
|
||||
}
|
||||
|
||||
if got.Source != modelSourceLocal {
|
||||
t.Fatalf("expected source local, got %v", got.Source)
|
||||
}
|
||||
|
||||
if got.Base != "gpt-oss:20b" {
|
||||
t.Fatalf("expected base gpt-oss:20b, got %q", got.Base)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit cloud with size maps to legacy cloud suffix", func(t *testing.T) {
|
||||
got, err := parseNormalizePullModelRef("gpt-oss:20b:cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
|
||||
}
|
||||
if got.Base != "gpt-oss:20b-cloud" {
|
||||
t.Fatalf("expected base gpt-oss:20b-cloud, got %q", got.Base)
|
||||
}
|
||||
if got.Name.String() != "registry.ollama.ai/library/gpt-oss:20b-cloud" {
|
||||
t.Fatalf("unexpected resolved name: %q", got.Name.String())
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("explicit cloud without size maps to cloud tag", func(t *testing.T) {
|
||||
got, err := parseNormalizePullModelRef("qwen3:cloud")
|
||||
if err != nil {
|
||||
t.Fatalf("parseNormalizePullModelRef returned error: %v", err)
|
||||
}
|
||||
if got.Base != "qwen3:cloud" {
|
||||
t.Fatalf("expected base qwen3:cloud, got %q", got.Base)
|
||||
}
|
||||
if got.Name.String() != "registry.ollama.ai/library/qwen3:cloud" {
|
||||
t.Fatalf("unexpected resolved name: %q", got.Name.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
150
server/routes.go
150
server/routes.go
@@ -64,6 +64,17 @@ const (
|
||||
cloudErrRemoteModelDetailsUnavailable = "remote model details are unavailable"
|
||||
)
|
||||
|
||||
func writeModelRefParseError(c *gin.Context, err error, fallbackStatus int, fallbackMessage string) {
|
||||
switch {
|
||||
case errors.Is(err, errConflictingModelSource):
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
case errors.Is(err, model.ErrUnqualifiedName):
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
||||
default:
|
||||
c.JSON(fallbackStatus, gin.H{"error": fallbackMessage})
|
||||
}
|
||||
}
|
||||
|
||||
func shouldUseHarmony(model *Model) bool {
|
||||
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
||||
// heuristic to check whether the template expects to be parsed via harmony:
|
||||
@@ -196,14 +207,22 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
// Ideally this is "invalid model name" but we're keeping with
|
||||
// what the API currently returns until we can change it.
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model))
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceCloud {
|
||||
// TODO(drifkin): evaluate an `/api/*` passthrough for cloud where the
|
||||
// original body (modulo model name normalization) is sent to cloud.
|
||||
req.Model = modelRef.Base
|
||||
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
name := modelRef.Name
|
||||
|
||||
resolvedName, _, err := s.resolveAlias(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -237,6 +256,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
return
|
||||
}
|
||||
|
||||
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
if disabled, _ := internalcloud.Status(); disabled {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)})
|
||||
@@ -676,6 +700,18 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusNotFound, fmt.Sprintf("model '%s' not found", req.Model))
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceCloud {
|
||||
req.Model = modelRef.Base
|
||||
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
var input []string
|
||||
|
||||
switch i := req.Input.(type) {
|
||||
@@ -698,7 +734,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
name, err := getExistingName(model.ParseName(req.Model))
|
||||
name, err := getExistingName(modelRef.Name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
return
|
||||
@@ -845,12 +881,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusBadRequest, "model is required")
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceCloud {
|
||||
req.Model = modelRef.Base
|
||||
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
name := modelRef.Name
|
||||
|
||||
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
|
||||
if err != nil {
|
||||
handleScheduleError(c, req.Model, err)
|
||||
@@ -892,12 +936,19 @@ func (s *Server) PullHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(cmp.Or(req.Model, req.Name))
|
||||
if !name.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
||||
// TEMP(drifkin): we're temporarily allowing to continue pulling cloud model
|
||||
// stub-files until we integrate cloud models into `/api/tags` (in which case
|
||||
// this roundabout way of "adding" cloud models won't be needed anymore). So
|
||||
// right here normalize any `:cloud` models into the legacy-style suffixes
|
||||
// `:<tag>-cloud` and `:cloud`
|
||||
modelRef, err := parseNormalizePullModelRef(cmp.Or(req.Model, req.Name))
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusBadRequest, errtypes.InvalidModelNameErrMsg)
|
||||
return
|
||||
}
|
||||
|
||||
name := modelRef.Name
|
||||
|
||||
name, err = getExistingName(name)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
@@ -1024,13 +1075,20 @@ func (s *Server) DeleteHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
n := model.ParseName(cmp.Or(r.Model, r.Name))
|
||||
if !n.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
|
||||
modelRef, err := parseNormalizePullModelRef(cmp.Or(r.Model, r.Name))
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, errConflictingModelSource):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
case errors.Is(err, model.ErrUnqualifiedName):
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))})
|
||||
default:
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
n, err := getExistingName(n)
|
||||
n, err := getExistingName(modelRef.Name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", cmp.Or(r.Model, r.Name))})
|
||||
return
|
||||
@@ -1079,6 +1137,20 @@ func (s *Server) ShowHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceCloud {
|
||||
req.Model = modelRef.Base
|
||||
proxyCloudJSONRequest(c, req, cloudErrRemoteModelDetailsUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
req.Model = modelRef.Base
|
||||
|
||||
resp, err := GetModelInfo(req)
|
||||
if err != nil {
|
||||
var statusErr api.StatusError
|
||||
@@ -1095,6 +1167,11 @@ func (s *Server) ShowHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceLocal && resp.RemoteHost != "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", modelRef.Original)})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
@@ -1631,18 +1708,20 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||
|
||||
// Inference (OpenAI compatibility)
|
||||
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||
// TODO(cloud-stage-a): apply Modelfile overlay deltas for local models with cloud
|
||||
// parents on v1 request families while preserving this explicit :cloud passthrough.
|
||||
r.POST("/v1/chat/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ChatMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/completions", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.CompletionsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/embeddings", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
r.GET("/v1/models/:model", cloudModelPathPassthroughMiddleware(cloudErrRemoteModelDetailsUnavailable), middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/responses", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ResponsesMiddleware(), s.ChatHandler)
|
||||
// OpenAI-compatible image generation endpoints
|
||||
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/generations", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/images/edits", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.ImageEditsMiddleware(), s.GenerateHandler)
|
||||
|
||||
// Inference (Anthropic compatibility)
|
||||
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/messages", cloudPassthroughMiddleware(cloudErrRemoteInferenceUnavailable), middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
|
||||
|
||||
if rc != nil {
|
||||
// wrap old with new
|
||||
@@ -2001,12 +2080,24 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(req.Model)
|
||||
if !name.IsValid() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
|
||||
modelRef, err := parseAndValidateModelRef(req.Model)
|
||||
if err != nil {
|
||||
writeModelRefParseError(c, err, http.StatusBadRequest, "model is required")
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceCloud {
|
||||
req.Model = modelRef.Base
|
||||
if c.GetBool(legacyCloudAnthropicKey) {
|
||||
proxyCloudJSONRequestWithPath(c, req, "/api/chat", cloudErrRemoteInferenceUnavailable)
|
||||
return
|
||||
}
|
||||
proxyCloudJSONRequest(c, req, cloudErrRemoteInferenceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
name := modelRef.Name
|
||||
|
||||
resolvedName, _, err := s.resolveAlias(name)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -2038,6 +2129,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if modelRef.Source == modelSourceLocal && m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
|
||||
return
|
||||
}
|
||||
|
||||
// expire the runner
|
||||
if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
|
||||
s.sched.expireRunner(m)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -794,6 +794,43 @@ func TestCreateAndShowRemoteModel(t *testing.T) {
|
||||
fmt.Printf("resp = %#v\n", resp)
|
||||
}
|
||||
|
||||
func TestCreateFromCloudSourceSuffix(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
var s Server
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-cloud-from-suffix",
|
||||
From: "gpt-oss:20b:cloud",
|
||||
Info: map[string]any{
|
||||
"capabilities": []string{"completion"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
w = createRequest(t, s.ShowHandler, api.ShowRequest{Model: "test-cloud-from-suffix"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.ShowResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.RemoteHost != "https://ollama.com:443" {
|
||||
t.Fatalf("expected remote host https://ollama.com:443, got %q", resp.RemoteHost)
|
||||
}
|
||||
|
||||
if resp.RemoteModel != "gpt-oss:20b" {
|
||||
t.Fatalf("expected remote model gpt-oss:20b, got %q", resp.RemoteModel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateLicenses(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -111,3 +111,32 @@ func TestDeleteDuplicateLayers(t *testing.T) {
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
||||
}
|
||||
|
||||
func TestDeleteCloudSourceNormalizesToLegacyName(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
|
||||
var s Server
|
||||
|
||||
_, digest := createBinFile(t, nil, nil)
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "gpt-oss:20b-cloud",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "gpt-oss", "20b-cloud"),
|
||||
})
|
||||
|
||||
w = createRequest(t, s.DeleteHandler, api.DeleteRequest{Name: "gpt-oss:20b:cloud"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d (%s)", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user