diff --git a/cmd/cmd.go b/cmd/cmd.go index 20da8cd2c..f3fca9366 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -132,6 +132,17 @@ func getModelfileName(cmd *cobra.Command) (string, error) { return absName, nil } +// isLocalhost returns true if the configured Ollama host is a loopback or unspecified address. +func isLocalhost() bool { + host := envconfig.Host() + h, _, _ := net.SplitHostPort(host.Host) + if h == "localhost" { + return true + } + ip := net.ParseIP(h) + return ip != nil && (ip.IsLoopback() || ip.IsUnspecified()) +} + func CreateHandler(cmd *cobra.Command, args []string) error { p := progress.NewProgress(os.Stderr) defer p.Stop() @@ -146,10 +157,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { // Check for --experimental flag for safetensors model creation experimental, _ := cmd.Flags().GetBool("experimental") if experimental { - host := envconfig.Host() - h, _, _ := net.SplitHostPort(host.Host) - ip := net.ParseIP(h) - if ip == nil || (!ip.IsLoopback() && !ip.IsUnspecified()) { + if !isLocalhost() { return errors.New("remote safetensor model creation not yet supported") } // Get Modelfile content - either from -f flag or default to "FROM ." @@ -221,10 +229,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error { if filename == "" { // No Modelfile found - check if current directory is an image gen model if create.IsTensorModelDir(".") { - host := envconfig.Host() - h, _, _ := net.SplitHostPort(host.Host) - ip := net.ParseIP(h) - if ip == nil || (!ip.IsLoopback() && !ip.IsUnspecified()) { + if !isLocalhost() { return errors.New("remote safetensor model creation not yet supported") } quantize, _ := cmd.Flags().GetString("quantize") diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index dfbd63a85..49852c02d 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -1928,3 +1928,38 @@ func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) { }) } } + +func TestIsLocalhost(t *testing.T) { + tests := []struct { + name string + host string + expected bool + }{ + {"default empty", "", true}, + {"localhost no port", "localhost", true}, + {"localhost with port", "localhost:11435", true}, + {"127.0.0.1 no port", "127.0.0.1", true}, + {"127.0.0.1 with port", "127.0.0.1:11434", true}, + {"0.0.0.0 no port", "0.0.0.0", true}, + {"0.0.0.0 with port", "0.0.0.0:11434", true}, + {"::1 no port", "::1", true}, + {"[::1] with port", "[::1]:11434", true}, + {"loopback with scheme", "http://localhost:11434", true}, + {"remote hostname", "example.com", false}, + {"remote hostname with port", "example.com:11434", false}, + {"remote IP", "192.168.1.1", false}, + {"remote IP with port", "192.168.1.1:11434", false}, + {"remote with scheme", "http://example.com:11434", false}, + {"https remote", "https://example.com:443", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv("OLLAMA_HOST", tt.host) + got := isLocalhost() + if got != tt.expected { + t.Errorf("isLocalhost() with OLLAMA_HOST=%q = %v, want %v", tt.host, got, tt.expected) + } + }) + } +}