Mirror of https://github.com/ollama/ollama.git
embeddings: added embedding command for cli (#12795)
This PR introduces a new `ollama embed` command that allows users to generate embeddings directly from the command line.

- Added `ollama embed MODEL [TEXT...]` command for generating text embeddings
- Supports both direct text arguments and stdin piping for scripted workflows
- Outputs embeddings as JSON arrays (one per line)

Co-authored-by: A-Akhil <akhilrahul70@gmail.com>
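For orientation (not part of the commit message), the README and docs changes below expose this through `ollama run` with an embedding-capable model; `embeddinggemma` is simply the example model used there:

```shell
# Embed text passed directly as an argument
ollama run embeddinggemma "Your text to embed"

# Or pipe text in for scripted workflows
echo "Your text to embed" | ollama run embeddinggemma
```

Both forms print the embedding to stdout as a JSON array.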
README.md (12 lines changed)
@@ -226,6 +226,18 @@ ollama ps
ollama stop llama3.2
```

### Generate embeddings from the CLI

```shell
ollama run embeddinggemma "Your text to embed"
```

You can also pipe text for scripted workflows:

```shell
echo "Your text to embed" | ollama run embeddinggemma
```

### Start Ollama

`ollama serve` is used when you want to start ollama without running the desktop application.
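Not part of the README change, but one plausible scripted use of the piped form, assuming `jq` is available and relying on the single-JSON-array output verified by the tests below:

```shell
# Count the dimensions of the returned embedding
echo "Your text to embed" | ollama run embeddinggemma | jq 'length'
```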
cmd/cmd.go (69 lines changed)
@@ -322,6 +322,44 @@ func StopHandler(cmd *cobra.Command, args []string) error {
	return nil
}

func generateEmbedding(cmd *cobra.Command, modelName, input string, keepAlive *api.Duration, truncate *bool, dimensions int) error {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		return err
	}

	req := &api.EmbedRequest{
		Model: modelName,
		Input: input,
	}
	if keepAlive != nil {
		req.KeepAlive = keepAlive
	}
	if truncate != nil {
		req.Truncate = truncate
	}
	if dimensions > 0 {
		req.Dimensions = dimensions
	}

	resp, err := client.Embed(cmd.Context(), req)
	if err != nil {
		return err
	}

	if len(resp.Embeddings) == 0 {
		return errors.New("no embeddings returned")
	}

	output, err := json.Marshal(resp.Embeddings[0])
	if err != nil {
		return err
	}
	fmt.Println(string(output))

	return nil
}

func RunHandler(cmd *cobra.Command, args []string) error {
	interactive := true
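For context (not part of the diff): `generateEmbedding` is a thin client-side wrapper around the server's `/api/embed` endpoint, the same route the tests below mock. A rough HTTP equivalent, assuming a local server on the default port and an illustrative model name, would be:

```shell
# Hypothetical direct call to the endpoint the helper wraps;
# "model" and "input" mirror the api.EmbedRequest fields above.
curl http://localhost:11434/api/embed -d '{
  "model": "embeddinggemma",
  "input": "Your text to embed"
}'
```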
@@ -386,7 +424,11 @@ func RunHandler(cmd *cobra.Command, args []string) error {
			return err
		}

		prompts = append([]string{string(in)}, prompts...)
		// Only prepend stdin content if it's not empty
		stdinContent := string(in)
		if len(stdinContent) > 0 {
			prompts = append([]string{stdinContent}, prompts...)
		}
		opts.ShowConnect = false
		opts.WordWrap = false
		interactive = false
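A behavioral detail this hunk preserves: piped stdin is prepended to any text arguments before the prompt is built, so stdin and arguments are combined into one input. A sketch of that interaction with an illustrative model name (printf avoids the trailing newline that echo would add to the embedded text):

```shell
printf 'piped text' | ollama run embeddinggemma additional args
# the embedded input becomes "piped text additional args"
```

This matches the expectation in TestRunEmbeddingModelPipedInput further down.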
@@ -452,6 +494,29 @@ func RunHandler(cmd *cobra.Command, args []string) error {

	opts.ParentModel = info.Details.ParentModel

	// Check if this is an embedding model
	isEmbeddingModel := slices.Contains(info.Capabilities, model.CapabilityEmbedding)

	// If it's an embedding model, handle embedding generation
	if isEmbeddingModel {
		if opts.Prompt == "" {
			return errors.New("embedding models require input text. Usage: ollama run " + name + " \"your text here\"")
		}

		// Get embedding-specific flags
		var truncate *bool
		if truncateFlag, err := cmd.Flags().GetBool("truncate"); err == nil && cmd.Flags().Changed("truncate") {
			truncate = &truncateFlag
		}

		dimensions, err := cmd.Flags().GetInt("dimensions")
		if err != nil {
			return err
		}

		return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions)
	}

	if interactive {
		if err := loadOrUnloadModel(cmd, &opts); err != nil {
			var sErr api.AuthorizationError
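Running an embedding model with no prompt now fails fast rather than dropping into the interactive REPL. With the example model name from the docs, the failure would look roughly like this (the message text comes from the errors.New call above; the "Error:" prefix is the CLI's usual error formatting):

```shell
$ ollama run embeddinggemma
Error: embedding models require input text. Usage: ollama run embeddinggemma "your text here"
```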
@@ -1684,6 +1749,8 @@ func NewCLI() *cobra.Command {
	runCmd.Flags().String("think", "", "Enable thinking mode: true/false or high/medium/low for supported models")
	runCmd.Flags().Lookup("think").NoOptDefVal = "true"
	runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
	runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead")
	runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)")

	stopCmd := &cobra.Command{
		Use:   "stop MODEL",
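The two new flags only take effect for models with the embedding capability. Illustrative usage, with an assumed model name, based on the flag help text registered above:

```shell
# Return an error instead of truncating input that exceeds the context length
ollama run embeddinggemma --truncate=false "a very long document ..."

# Truncate the returned embedding to 256 dimensions
ollama run embeddinggemma --dimensions 256 "Your text to embed"
```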
cmd/cmd_test.go (324 lines changed)
@@ -355,6 +355,330 @@ func TestDeleteHandler(t *testing.T) {
	}
}

func TestRunEmbeddingModel(t *testing.T) {
	reqCh := make(chan api.EmbedRequest, 1)
	mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/api/show" && r.Method == http.MethodPost {
			w.Header().Set("Content-Type", "application/json")
			if err := json.NewEncoder(w).Encode(api.ShowResponse{
				Capabilities: []model.Capability{model.CapabilityEmbedding},
			}); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
			}
			return
		}
		if r.URL.Path == "/api/embed" && r.Method == http.MethodPost {
			var req api.EmbedRequest
			if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
				http.Error(w, err.Error(), http.StatusBadRequest)
				return
			}
			reqCh <- req
			w.Header().Set("Content-Type", "application/json")
			if err := json.NewEncoder(w).Encode(api.EmbedResponse{
				Model:      "test-embedding-model",
				Embeddings: [][]float32{{0.1, 0.2, 0.3}},
			}); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
			}
			return
		}
		http.NotFound(w, r)
	}))

	t.Setenv("OLLAMA_HOST", mockServer.URL)
	t.Cleanup(mockServer.Close)

	cmd := &cobra.Command{}
	cmd.SetContext(t.Context())
	cmd.Flags().String("keepalive", "", "")
	cmd.Flags().Bool("truncate", false, "")
	cmd.Flags().Int("dimensions", 0, "")
	cmd.Flags().Bool("verbose", false, "")
	cmd.Flags().Bool("insecure", false, "")
	cmd.Flags().Bool("nowordwrap", false, "")
	cmd.Flags().String("format", "", "")
	cmd.Flags().String("think", "", "")
	cmd.Flags().Bool("hidethinking", false, "")

	oldStdout := os.Stdout
	r, w, _ := os.Pipe()
	os.Stdout = w

	errCh := make(chan error, 1)
	go func() {
		errCh <- RunHandler(cmd, []string{"test-embedding-model", "hello", "world"})
	}()

	err := <-errCh
	w.Close()
	os.Stdout = oldStdout

	if err != nil {
		t.Fatalf("RunHandler returned error: %v", err)
	}

	var out bytes.Buffer
	io.Copy(&out, r)

	select {
	case req := <-reqCh:
		inputText, _ := req.Input.(string)
		if diff := cmp.Diff("hello world", inputText); diff != "" {
			t.Errorf("unexpected input (-want +got):\n%s", diff)
		}
		if req.Truncate != nil {
			t.Errorf("expected truncate to be nil, got %v", *req.Truncate)
		}
		if req.KeepAlive != nil {
			t.Errorf("expected keepalive to be nil, got %v", req.KeepAlive)
		}
		if req.Dimensions != 0 {
			t.Errorf("expected dimensions to be 0, got %d", req.Dimensions)
		}
	default:
		t.Fatal("server did not receive embed request")
	}

	expectOutput := "[0.1,0.2,0.3]\n"
	if diff := cmp.Diff(expectOutput, out.String()); diff != "" {
		t.Errorf("unexpected output (-want +got):\n%s", diff)
	}
}

func TestRunEmbeddingModelWithFlags(t *testing.T) {
	reqCh := make(chan api.EmbedRequest, 1)
	mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/api/show" && r.Method == http.MethodPost {
			w.Header().Set("Content-Type", "application/json")
			if err := json.NewEncoder(w).Encode(api.ShowResponse{
				Capabilities: []model.Capability{model.CapabilityEmbedding},
			}); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
			}
			return
		}
		if r.URL.Path == "/api/embed" && r.Method == http.MethodPost {
			var req api.EmbedRequest
			if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
				http.Error(w, err.Error(), http.StatusBadRequest)
				return
			}
			reqCh <- req
			w.Header().Set("Content-Type", "application/json")
			if err := json.NewEncoder(w).Encode(api.EmbedResponse{
				Model:        "test-embedding-model",
				Embeddings:   [][]float32{{0.4, 0.5}},
				LoadDuration: 5 * time.Millisecond,
			}); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
			}
			return
		}
		http.NotFound(w, r)
	}))

	t.Setenv("OLLAMA_HOST", mockServer.URL)
	t.Cleanup(mockServer.Close)

	cmd := &cobra.Command{}
	cmd.SetContext(t.Context())
	cmd.Flags().String("keepalive", "", "")
	cmd.Flags().Bool("truncate", false, "")
	cmd.Flags().Int("dimensions", 0, "")
	cmd.Flags().Bool("verbose", false, "")
	cmd.Flags().Bool("insecure", false, "")
	cmd.Flags().Bool("nowordwrap", false, "")
	cmd.Flags().String("format", "", "")
	cmd.Flags().String("think", "", "")
	cmd.Flags().Bool("hidethinking", false, "")

	if err := cmd.Flags().Set("truncate", "true"); err != nil {
		t.Fatalf("failed to set truncate flag: %v", err)
	}
	if err := cmd.Flags().Set("dimensions", "2"); err != nil {
		t.Fatalf("failed to set dimensions flag: %v", err)
	}
	if err := cmd.Flags().Set("keepalive", "5m"); err != nil {
		t.Fatalf("failed to set keepalive flag: %v", err)
	}

	oldStdout := os.Stdout
	r, w, _ := os.Pipe()
	os.Stdout = w

	errCh := make(chan error, 1)
	go func() {
		errCh <- RunHandler(cmd, []string{"test-embedding-model", "test", "input"})
	}()

	err := <-errCh
	w.Close()
	os.Stdout = oldStdout

	if err != nil {
		t.Fatalf("RunHandler returned error: %v", err)
	}

	var out bytes.Buffer
	io.Copy(&out, r)

	select {
	case req := <-reqCh:
		inputText, _ := req.Input.(string)
		if diff := cmp.Diff("test input", inputText); diff != "" {
			t.Errorf("unexpected input (-want +got):\n%s", diff)
		}
		if req.Truncate == nil || !*req.Truncate {
			t.Errorf("expected truncate pointer true, got %v", req.Truncate)
		}
		if req.Dimensions != 2 {
			t.Errorf("expected dimensions 2, got %d", req.Dimensions)
		}
		if req.KeepAlive == nil || req.KeepAlive.Duration != 5*time.Minute {
			t.Errorf("unexpected keepalive duration: %v", req.KeepAlive)
		}
	default:
		t.Fatal("server did not receive embed request")
	}

	expectOutput := "[0.4,0.5]\n"
	if diff := cmp.Diff(expectOutput, out.String()); diff != "" {
		t.Errorf("unexpected output (-want +got):\n%s", diff)
	}
}

func TestRunEmbeddingModelPipedInput(t *testing.T) {
	reqCh := make(chan api.EmbedRequest, 1)
	mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/api/show" && r.Method == http.MethodPost {
			w.Header().Set("Content-Type", "application/json")
			if err := json.NewEncoder(w).Encode(api.ShowResponse{
				Capabilities: []model.Capability{model.CapabilityEmbedding},
			}); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
			}
			return
		}
		if r.URL.Path == "/api/embed" && r.Method == http.MethodPost {
			var req api.EmbedRequest
			if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
				http.Error(w, err.Error(), http.StatusBadRequest)
				return
			}
			reqCh <- req
			w.Header().Set("Content-Type", "application/json")
			if err := json.NewEncoder(w).Encode(api.EmbedResponse{
				Model:      "test-embedding-model",
				Embeddings: [][]float32{{0.6, 0.7}},
			}); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
			}
			return
		}
		http.NotFound(w, r)
	}))

	t.Setenv("OLLAMA_HOST", mockServer.URL)
	t.Cleanup(mockServer.Close)

	cmd := &cobra.Command{}
	cmd.SetContext(t.Context())
	cmd.Flags().String("keepalive", "", "")
	cmd.Flags().Bool("truncate", false, "")
	cmd.Flags().Int("dimensions", 0, "")
	cmd.Flags().Bool("verbose", false, "")
	cmd.Flags().Bool("insecure", false, "")
	cmd.Flags().Bool("nowordwrap", false, "")
	cmd.Flags().String("format", "", "")
	cmd.Flags().String("think", "", "")
	cmd.Flags().Bool("hidethinking", false, "")

	// Capture stdin
	oldStdin := os.Stdin
	stdinR, stdinW, _ := os.Pipe()
	os.Stdin = stdinR
	stdinW.Write([]byte("piped text"))
	stdinW.Close()

	// Capture stdout
	oldStdout := os.Stdout
	stdoutR, stdoutW, _ := os.Pipe()
	os.Stdout = stdoutW

	errCh := make(chan error, 1)
	go func() {
		errCh <- RunHandler(cmd, []string{"test-embedding-model", "additional", "args"})
	}()

	err := <-errCh
	stdoutW.Close()
	os.Stdout = oldStdout
	os.Stdin = oldStdin

	if err != nil {
		t.Fatalf("RunHandler returned error: %v", err)
	}

	var out bytes.Buffer
	io.Copy(&out, stdoutR)

	select {
	case req := <-reqCh:
		inputText, _ := req.Input.(string)
		// Should combine piped input with command line args
		if diff := cmp.Diff("piped text additional args", inputText); diff != "" {
			t.Errorf("unexpected input (-want +got):\n%s", diff)
		}
	default:
		t.Fatal("server did not receive embed request")
	}

	expectOutput := "[0.6,0.7]\n"
	if diff := cmp.Diff(expectOutput, out.String()); diff != "" {
		t.Errorf("unexpected output (-want +got):\n%s", diff)
	}
}

func TestRunEmbeddingModelNoInput(t *testing.T) {
	mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.URL.Path == "/api/show" && r.Method == http.MethodPost {
			w.Header().Set("Content-Type", "application/json")
			if err := json.NewEncoder(w).Encode(api.ShowResponse{
				Capabilities: []model.Capability{model.CapabilityEmbedding},
			}); err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
			}
			return
		}
		http.NotFound(w, r)
	}))

	t.Setenv("OLLAMA_HOST", mockServer.URL)
	t.Cleanup(mockServer.Close)

	cmd := &cobra.Command{}
	cmd.SetContext(t.Context())
	cmd.Flags().String("keepalive", "", "")
	cmd.Flags().Bool("truncate", false, "")
	cmd.Flags().Int("dimensions", 0, "")
	cmd.Flags().Bool("verbose", false, "")
	cmd.Flags().Bool("insecure", false, "")
	cmd.Flags().Bool("nowordwrap", false, "")
	cmd.Flags().String("format", "", "")
	cmd.Flags().String("think", "", "")
	cmd.Flags().Bool("hidethinking", false, "")

	cmd.SetOut(io.Discard)
	cmd.SetErr(io.Discard)

	// Test with no input arguments (only model name)
	err := RunHandler(cmd, []string{"test-embedding-model"})
	if err == nil || !strings.Contains(err.Error(), "embedding models require input text") {
		t.Fatalf("expected error about missing input, got %v", err)
	}
}

func TestGetModelfileName(t *testing.T) {
	tests := []struct {
		name string
docs/cli.mdx (12 lines changed)
@@ -25,6 +25,18 @@ I'm a basic program that prints the famous "Hello, world!" message to the consol
ollama run gemma3 "What's in this image? /Users/jmorgan/Desktop/smile.png"
```

### Generate embeddings

```
ollama run embeddinggemma "Hello world"
```

Output is a JSON array:

```
echo "Hello world" | ollama run nomic-embed-text
```

### Download a model

```
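The second snippet in the docs hunk above is another invocation (piping into nomic-embed-text) rather than sample output. For reference, the CLI tests expect a single JSON array per run; with the mock values from cmd/cmd_test.go the printed result looks like this, while a real model returns a much longer vector:

```shell
$ ollama run embeddinggemma "Hello world"
[0.1,0.2,0.3]
```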