diff --git a/model/renderers/glmocr.go b/model/renderers/glmocr.go
index b141da07d..05e7be08e 100644
--- a/model/renderers/glmocr.go
+++ b/model/renderers/glmocr.go
@@ -8,7 +8,29 @@ import (
 	"github.com/ollama/ollama/api"
 )
 
-type GlmOcrRenderer struct{}
+// GlmOcrRenderer renders chat messages into the prompt used by the "glm-ocr" renderer.
+type GlmOcrRenderer struct {
+	// useImgTags controls whether an "[img-N]" placeholder is written ahead
+	// of the message content for each attached image.
+	useImgTags bool
+}
+
+// renderContent returns the content of message prefixed with one "[img-N]"
+// placeholder per attached image when useImgTags is set. imageOffset is the
+// index given to the first placeholder; the returned int is the index the
+// next message should start from. When useImgTags is unset the content and
+// offset are returned unchanged.
+func (r *GlmOcrRenderer) renderContent(message api.Message, imageOffset int) (string, int) {
+	var sb strings.Builder
+	if r.useImgTags {
+		for range message.Images {
+			fmt.Fprintf(&sb, "[img-%d]", imageOffset)
+			imageOffset++
+		}
+	}
+	sb.WriteString(message.Content)
+	return sb.String(), imageOffset
+}
 
 func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
 	var sb strings.Builder
@@ -38,11 +60,14 @@ func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkV
 		thinkingExplicitlySet = true
 	}
 
+	imageOffset := 0
 	for i, message := range messages {
 		switch message.Role {
 		case "user":
 			sb.WriteString("<|user|>\n")
-			sb.WriteString(message.Content)
+			content, nextOffset := r.renderContent(message, imageOffset)
+			imageOffset = nextOffset
+			sb.WriteString(content)
 			if thinkingExplicitlySet && !enableThinking && !strings.HasSuffix(message.Content, "/nothink") {
 				sb.WriteString("/nothink")
 			}
diff --git a/model/renderers/glmocr_test.go b/model/renderers/glmocr_test.go
new file mode 100644
index 000000000..dbc611ccb
--- /dev/null
+++ b/model/renderers/glmocr_test.go
@@ -0,0 +1,99 @@
+package renderers
+
+import (
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+
+	"github.com/ollama/ollama/api"
+)
+
+func TestGlmOcrRenderer_Images(t *testing.T) {
+	tests := []struct {
+		name     string
+		renderer *GlmOcrRenderer
+		messages []api.Message
+		expected string
+	}{
+		{
+			name:     "use_img_tags_single_image",
+			renderer: &GlmOcrRenderer{useImgTags: true},
+			messages: []api.Message{
+				{
+					Role:    "user",
+					Content: "Describe this image.",
+					Images:  []api.ImageData{api.ImageData("img1")},
+				},
+			},
+			expected: "[gMASK]<|user|>\n[img-0]Describe this image.<|assistant|>\n",
+		},
+		{
+			name:     "use_img_tags_multiple_images",
+			renderer: &GlmOcrRenderer{useImgTags: true},
+			messages: []api.Message{
+				{
+					Role:    "user",
+					Content: "Describe these images.",
+					Images:  []api.ImageData{api.ImageData("img1"), api.ImageData("img2")},
+				},
+			},
+			expected: "[gMASK]<|user|>\n[img-0][img-1]Describe these images.<|assistant|>\n",
+		},
+		{
+			name:     "multi_turn_increments_image_offset",
+			renderer: &GlmOcrRenderer{useImgTags: true},
+			messages: []api.Message{
+				{
+					Role:    "user",
+					Content: "First image",
+					Images:  []api.ImageData{api.ImageData("img1")},
+				},
+				{
+					Role:    "assistant",
+					Content: "Processed.",
+				},
+				{
+					Role:    "user",
+					Content: "Second image",
+					Images:  []api.ImageData{api.ImageData("img2")},
+				},
+			},
+			expected: "[gMASK]<|user|>\n[img-0]First image<|assistant|>\n\nProcessed.\n<|user|>\n[img-1]Second image<|assistant|>\n",
+		},
+		{
+			name:     "default_no_img_tags",
+			renderer: &GlmOcrRenderer{},
+			messages: []api.Message{
+				{
+					Role:    "user",
+					Content: "No image tags expected.",
+					Images:  []api.ImageData{api.ImageData("img1")},
+				},
+			},
+			expected: "[gMASK]<|user|>\nNo image tags expected.<|assistant|>\n",
+		},
+		{
+			name:     "no_images_content_unchanged",
+			renderer: &GlmOcrRenderer{useImgTags: true},
+			messages: []api.Message{
+				{
+					Role:    "user",
+					Content: "Text only message.",
+				},
+			},
+			expected: "[gMASK]<|user|>\nText only message.<|assistant|>\n",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := tt.renderer.Render(tt.messages, nil, nil)
+			if err != nil {
+				t.Fatalf("Render() error = %v", err)
+			}
+			if diff := cmp.Diff(tt.expected, got); diff != "" {
+				t.Fatalf("Render() mismatch (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
diff --git a/model/renderers/renderer.go b/model/renderers/renderer.go
index 43001d9c1..7309d38f5 100644
--- a/model/renderers/renderer.go
+++ b/model/renderers/renderer.go
@@ -86,7 +86,7 @@ func rendererForName(name string) Renderer {
 	case "glm-4.7":
 		return &GLM47Renderer{}
 	case "glm-ocr":
-		return &GlmOcrRenderer{}
+		return &GlmOcrRenderer{useImgTags: RenderImgTags}
 	case "lfm2":
 		return &LFM2Renderer{IsThinking: false, useImgTags: RenderImgTags}
 	case "lfm2-thinking":
diff --git a/server/prompt_test.go b/server/prompt_test.go
index 024957007..8bbadb22d 100644
--- a/server/prompt_test.go
+++ b/server/prompt_test.go
@@ -3,6 +3,7 @@ package server
 import (
 	"bytes"
 	"context"
+	"strings"
 	"testing"
 
 	"github.com/google/go-cmp/cmp"
@@ -366,3 +367,33 @@ func TestChatPromptRendererDoesNotRewriteMessageContent(t *testing.T) {
 		t.Fatal("prompt is empty")
 	}
 }
+
+func TestChatPromptGLMOcrRendererAddsImageTags(t *testing.T) {
+	msgs := []api.Message{
+		{
+			Role:    "user",
+			Content: "extract text",
+			Images:  []api.ImageData{[]byte("img-1"), []byte("img-2")},
+		},
+	}
+
+	m := Model{
+		Config:         model.ConfigV2{Renderer: "glm-ocr"},
+		ProjectorPaths: []string{"vision"},
+	}
+	opts := api.Options{Runner: api.Runner{NumCtx: 8192}}
+	think := false
+
+	prompt, images, err := chatPrompt(t.Context(), &m, mockRunner{}.Tokenize, &opts, msgs, nil, &api.ThinkValue{Value: think}, true)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if got, want := len(images), 2; got != want {
+		t.Fatalf("len(images) = %d, want %d", got, want)
+	}
+
+	if !strings.Contains(prompt, "<|user|>\n[img-0][img-1]extract text") {
+		t.Fatalf("prompt missing glm-ocr image tags, got: %q", prompt)
+	}
+}