diff --git a/x/imagegen/transfer/download.go b/x/imagegen/transfer/download.go index 5395c0c22..97738a870 100644 --- a/x/imagegen/transfer/download.go +++ b/x/imagegen/transfer/download.go @@ -45,24 +45,33 @@ func download(ctx context.Context, opts DownloadOptions) error { return nil } - // Filter existing - var blobs []Blob + // Calculate total from all blobs (for accurate progress reporting on resume) var total int64 + for _, b := range opts.Blobs { + total += b.Size + } + + // Filter out already-downloaded blobs and track completed bytes + var blobs []Blob + var alreadyCompleted int64 for _, b := range opts.Blobs { if fi, _ := os.Stat(filepath.Join(opts.DestDir, digestToPath(b.Digest))); fi != nil && fi.Size() == b.Size { if opts.Logger != nil { opts.Logger.Debug("blob already exists", "digest", b.Digest, "size", b.Size) } + alreadyCompleted += b.Size continue } blobs = append(blobs, b) - total += b.Size } if len(blobs) == 0 { return nil } token := opts.Token + progress := newProgressTracker(total, opts.Progress) + progress.add(alreadyCompleted) // Report already-downloaded bytes upfront + d := &downloader{ client: cmp.Or(opts.Client, defaultClient), baseURL: opts.BaseURL, @@ -72,7 +81,7 @@ func download(ctx context.Context, opts DownloadOptions) error { getToken: opts.GetToken, userAgent: cmp.Or(opts.UserAgent, defaultUserAgent), stallTimeout: cmp.Or(opts.StallTimeout, defaultStallTimeout), - progress: newProgressTracker(total, opts.Progress), + progress: progress, speeds: &speedTracker{}, logger: opts.Logger, } diff --git a/x/imagegen/transfer/transfer.go b/x/imagegen/transfer/transfer.go index fa1ff8a11..05842f065 100644 --- a/x/imagegen/transfer/transfer.go +++ b/x/imagegen/transfer/transfer.go @@ -110,8 +110,6 @@ var defaultClient = &http.Client{ MaxIdleConnsPerHost: 100, IdleConnTimeout: 90 * time.Second, }, - Timeout: 5 * time.Minute, - // Don't follow redirects automatically - we handle them manually CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, diff --git a/x/imagegen/transfer/transfer_test.go b/x/imagegen/transfer/transfer_test.go index 3218577bf..78e386ed9 100644 --- a/x/imagegen/transfer/transfer_test.go +++ b/x/imagegen/transfer/transfer_test.go @@ -284,6 +284,83 @@ func TestDownloadSkipsExisting(t *testing.T) { } } +func TestDownloadResumeProgressTotal(t *testing.T) { + // Test that when resuming a download with some blobs already present: + // 1. Total reflects ALL blob sizes (not just remaining) + // 2. Completed starts at the size of already-downloaded blobs + serverDir := t.TempDir() + blob1, data1 := createTestBlob(t, serverDir, 1000) + blob2, data2 := createTestBlob(t, serverDir, 2000) + blob3, data3 := createTestBlob(t, serverDir, 3000) + + // Pre-populate client with blob1 and blob2 (simulating partial download) + clientDir := t.TempDir() + for _, b := range []struct { + blob Blob + data []byte + }{{blob1, data1}, {blob2, data2}} { + path := filepath.Join(clientDir, digestToPath(b.blob.Digest)) + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(path, b.data, 0o644); err != nil { + t.Fatal(err) + } + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + digest := filepath.Base(r.URL.Path) + path := filepath.Join(serverDir, digestToPath(digest)) + data, err := os.ReadFile(path) + if err != nil { + http.NotFound(w, r) + return + } + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(data))) + w.WriteHeader(http.StatusOK) + w.Write(data) + })) + defer server.Close() + + var firstCompleted, firstTotal int64 + var gotFirstProgress bool + var mu sync.Mutex + + err := Download(context.Background(), DownloadOptions{ + Blobs: []Blob{blob1, blob2, blob3}, + BaseURL: server.URL, + DestDir: clientDir, + Concurrency: 1, + Progress: func(completed, total int64) { + mu.Lock() + defer mu.Unlock() + if !gotFirstProgress { + firstCompleted = completed + firstTotal = total + gotFirstProgress = true + } + }, + }) + if err != nil { + t.Fatalf("Download failed: %v", err) + } + + // Total should be sum of ALL blobs, not just blob3 + expectedTotal := blob1.Size + blob2.Size + blob3.Size + if firstTotal != expectedTotal { + t.Errorf("Total = %d, want %d (should include all blobs)", firstTotal, expectedTotal) + } + + // First progress call should show already-completed bytes from blob1+blob2 + expectedCompleted := blob1.Size + blob2.Size + if firstCompleted < expectedCompleted { + t.Errorf("First completed = %d, want >= %d (should include already-downloaded blobs)", firstCompleted, expectedCompleted) + } + + // Verify blob3 was downloaded + verifyBlob(t, clientDir, blob3, data3) +} + func TestDownloadDigestMismatch(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Return wrong data