mirror of
https://github.com/ollama/ollama.git
synced 2026-04-30 17:58:49 -05:00
fix: use api.GenerateRequest for image generation test (#13793)
Remove non-existent x/imagegen/api import and use the standard api.GenerateRequest/GenerateResponse with the Image field instead.
This commit is contained in:
@@ -3,18 +3,14 @@
|
|||||||
package integration
|
package integration
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
imagegenapi "github.com/ollama/ollama/x/imagegen/api"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestImageGeneration(t *testing.T) {
|
func TestImageGeneration(t *testing.T) {
|
||||||
@@ -41,7 +37,7 @@ func TestImageGeneration(t *testing.T) {
|
|||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
client, testEndpoint, cleanup := InitServerConnection(ctx, t)
|
client, _, cleanup := InitServerConnection(ctx, t)
|
||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
// Pull both models
|
// Pull both models
|
||||||
@@ -54,7 +50,7 @@ func TestImageGeneration(t *testing.T) {
|
|||||||
|
|
||||||
// Generate the image
|
// Generate the image
|
||||||
t.Logf("Generating image with prompt: %s", tc.prompt)
|
t.Logf("Generating image with prompt: %s", tc.prompt)
|
||||||
imageBase64, err := generateImage(ctx, testEndpoint, tc.imageGenModel, tc.prompt)
|
imageBase64, err := generateImage(ctx, client, tc.imageGenModel, tc.prompt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if strings.Contains(err.Error(), "image generation not available") {
|
if strings.Contains(err.Error(), "image generation not available") {
|
||||||
t.Skip("Target system does not support image generation")
|
t.Skip("Target system does not support image generation")
|
||||||
@@ -127,48 +123,26 @@ func TestImageGeneration(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateImage calls the OpenAI-compatible image generation API and returns the base64 image data
|
// generateImage calls the Ollama API to generate an image and returns the base64 image data
|
||||||
func generateImage(ctx context.Context, endpoint, model, prompt string) (string, error) {
|
func generateImage(ctx context.Context, client *api.Client, model, prompt string) (string, error) {
|
||||||
reqBody := imagegenapi.ImageGenerationRequest{
|
var imageBase64 string
|
||||||
Model: model,
|
|
||||||
Prompt: prompt,
|
|
||||||
N: 1,
|
|
||||||
Size: "512x512",
|
|
||||||
ResponseFormat: "b64_json",
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonBody, err := json.Marshal(reqBody)
|
err := client.Generate(ctx, &api.GenerateRequest{
|
||||||
|
Model: model,
|
||||||
|
Prompt: prompt,
|
||||||
|
}, func(resp api.GenerateResponse) error {
|
||||||
|
if resp.Image != "" {
|
||||||
|
imageBase64 = resp.Image
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("failed to marshal request: %w", err)
|
return "", fmt.Errorf("failed to generate image: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf("http://%s/v1/images/generations", endpoint)
|
if imageBase64 == "" {
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody))
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to create request: %w", err)
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
|
|
||||||
resp, err := http.DefaultClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to send request: %w", err)
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
buf.ReadFrom(resp.Body)
|
|
||||||
return "", fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, buf.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
var genResp imagegenapi.ImageGenerationResponse
|
|
||||||
if err := json.NewDecoder(resp.Body).Decode(&genResp); err != nil {
|
|
||||||
return "", fmt.Errorf("failed to decode response: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(genResp.Data) == 0 {
|
|
||||||
return "", fmt.Errorf("no image data in response")
|
return "", fmt.Errorf("no image data in response")
|
||||||
}
|
}
|
||||||
|
|
||||||
return genResp.Data[0].B64JSON, nil
|
return imageBase64, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user