diff --git a/cmd/cmd.go b/cmd/cmd.go index 6ca13f1af..702bbc1aa 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -2260,7 +2260,7 @@ func NewCLI() *cobra.Command { switch cmd { case runCmd: imagegen.AppendFlagsDocs(cmd) - appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]}) + appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_EDITOR"], envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]}) case serveCmd: appendEnvDocs(cmd, []envconfig.EnvVar{ envVars["OLLAMA_DEBUG"], diff --git a/cmd/editor_unix.go b/cmd/editor_unix.go new file mode 100644 index 000000000..0a7848c83 --- /dev/null +++ b/cmd/editor_unix.go @@ -0,0 +1,5 @@ +//go:build !windows + +package cmd + +const defaultEditor = "vi" diff --git a/cmd/editor_windows.go b/cmd/editor_windows.go new file mode 100644 index 000000000..ed428859d --- /dev/null +++ b/cmd/editor_windows.go @@ -0,0 +1,5 @@ +//go:build windows + +package cmd + +const defaultEditor = "edit" diff --git a/cmd/interactive.go b/cmd/interactive.go index a3c98c68e..1f91f9eca 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "os" + "os/exec" "path/filepath" "regexp" "slices" @@ -79,6 +80,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { fmt.Fprintln(os.Stderr, " Ctrl + w Delete the word before the cursor") fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, " Ctrl + l Clear the screen") + fmt.Fprintln(os.Stderr, " Ctrl + g Open default editor to compose a prompt") fmt.Fprintln(os.Stderr, " Ctrl + c Stop the model from responding") fmt.Fprintln(os.Stderr, " Ctrl + d Exit ollama (/bye)") fmt.Fprintln(os.Stderr, "") @@ -147,6 +149,18 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { scanner.Prompt.UseAlt = false sb.Reset() + continue + case errors.Is(err, readline.ErrEditPrompt): + sb.Reset() + content, err := editInExternalEditor(line) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + continue + } + if strings.TrimSpace(content) == "" { + continue + } + scanner.Prefill = content continue case err != nil: return err @@ -598,6 +612,57 @@ func extractFileData(input string) (string, []api.ImageData, error) { return strings.TrimSpace(input), imgs, nil } +func editInExternalEditor(content string) (string, error) { + editor := envconfig.Editor() + if editor == "" { + editor = os.Getenv("VISUAL") + } + if editor == "" { + editor = os.Getenv("EDITOR") + } + if editor == "" { + editor = defaultEditor + } + + // Check that the editor binary exists + name := strings.Fields(editor)[0] + if _, err := exec.LookPath(name); err != nil { + return "", fmt.Errorf("editor %q not found, set OLLAMA_EDITOR to the path of your preferred editor", name) + } + + tmpFile, err := os.CreateTemp("", "ollama-prompt-*.txt") + if err != nil { + return "", fmt.Errorf("creating temp file: %w", err) + } + defer os.Remove(tmpFile.Name()) + + if content != "" { + if _, err := tmpFile.WriteString(content); err != nil { + tmpFile.Close() + return "", fmt.Errorf("writing to temp file: %w", err) + } + } + tmpFile.Close() + + args := strings.Fields(editor) + args = append(args, tmpFile.Name()) + cmd := exec.Command(args[0], args[1:]...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("editor exited with error: %w", err) + } + + data, err := os.ReadFile(tmpFile.Name()) + if err != nil { + return "", fmt.Errorf("reading temp file: %w", err) + } + + return strings.TrimRight(string(data), "\n"), nil +} + func getImageData(filePath string) ([]byte, error) { file, err := os.Open(filePath) if err != nil { diff --git a/envconfig/config.go b/envconfig/config.go index a7466a84a..96886237c 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -216,6 +216,7 @@ func String(s string) func() string { var ( LLMLibrary = String("OLLAMA_LLM_LIBRARY") + Editor = String("OLLAMA_EDITOR") CudaVisibleDevices = String("CUDA_VISIBLE_DEVICES") HipVisibleDevices = String("HIP_VISIBLE_DEVICES") @@ -291,6 +292,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"}, "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, "OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"}, + "OLLAMA_EDITOR": {"OLLAMA_EDITOR", Editor(), "Path to editor for interactive prompt editing (Ctrl+G)"}, "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"}, "OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"}, diff --git a/readline/errors.go b/readline/errors.go index bb3fbd473..1be5213e5 100644 --- a/readline/errors.go +++ b/readline/errors.go @@ -5,6 +5,7 @@ import ( ) var ErrInterrupt = errors.New("Interrupt") +var ErrEditPrompt = errors.New("EditPrompt") type InterruptError struct { Line []rune diff --git a/readline/readline.go b/readline/readline.go index 42464da50..0113aa2ce 100644 --- a/readline/readline.go +++ b/readline/readline.go @@ -41,6 +41,7 @@ type Instance struct { Terminal *Terminal History *History Pasting bool + Prefill string pastedLines []string } @@ -89,6 +90,27 @@ func (i *Instance) Readline() (string, error) { buf, _ := NewBuffer(i.Prompt) + // Prefill the buffer with any text that we received from an external editor + if i.Prefill != "" { + lines := strings.Split(i.Prefill, "\n") + i.Prefill = "" + for idx, l := range lines { + for _, r := range l { + buf.Add(r) + } + if idx < len(lines)-1 { + i.pastedLines = append(i.pastedLines, buf.String()) + buf.Buf.Clear() + buf.Pos = 0 + buf.DisplayPos = 0 + buf.LineHasSpace.Clear() + fmt.Println() + fmt.Print(i.Prompt.AltPrompt) + i.Prompt.UseAlt = true + } + } + } + var esc bool var escex bool var metaDel bool @@ -251,6 +273,29 @@ func (i *Instance) Readline() (string, error) { buf.ClearScreen() case CharCtrlW: buf.DeleteWord() + case CharBell: + output := buf.String() + numPastedLines := len(i.pastedLines) + if numPastedLines > 0 { + output = strings.Join(i.pastedLines, "\n") + "\n" + output + i.pastedLines = nil + } + + // Move cursor to the last display line of the current buffer + currLine := buf.DisplayPos / buf.LineWidth + lastLine := buf.DisplaySize() / buf.LineWidth + if lastLine > currLine { + fmt.Print(CursorDownN(lastLine - currLine)) + } + + // Clear all lines from bottom to top: buffer wrapped lines + pasted lines + for range lastLine + numPastedLines { + fmt.Print(CursorBOL + ClearToEOL + CursorUp) + } + fmt.Print(CursorBOL + ClearToEOL) + + i.Prompt.UseAlt = false + return output, ErrEditPrompt case CharCtrlZ: fd := os.Stdin.Fd() return handleCharCtrlZ(fd, i.Terminal.termios)