Mirror of https://github.com/ollama/ollama.git (synced 2026-05-07 00:22:43 -05:00)
[GH-ISSUE #5321] Llama3: Generated outputs inconsistent despite seed and temperature #65369
Open · opened 2026-05-03 20:58:26 -05:00 by GiteaMirror · 22 comments
Originally created by @d-kleine on GitHub (Jun 27, 2024).
Original GitHub issue: https://github.com/ollama/ollama/issues/5321
What is the issue?
Follow-up of #586
Even though the output is supposed to be deterministic and reproducible with a fixed seed, a temperature set to 0, and a fixed num_ctx, the generated output of Llama 3 differs slightly between the first execution of this code and the second execution (without a kernel restart). All following executions match the second execution. Code snippet taken from LLMs from scratch - Evaluation with Ollama:
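The original snippet is not reproduced in this mirror. Purely as a hedged sketch of the kind of reproducibility check being described (model name, prompt, and parameter values are placeholders; the requests package is assumed to be installed):

```python
import requests  # third-party HTTP client, assumed to be installed

URL = "http://localhost:11434/api/generate"

# Placeholder model, prompt and parameter values; not the original snippet.
payload = {
    "model": "llama3",
    "prompt": "Why is the sky blue?",
    "stream": False,
    "options": {"seed": 123, "temperature": 0, "num_ctx": 2048},
}

# With a fixed seed, temperature 0 and a fixed num_ctx, every call is expected
# to return the same text; the report above is that the first call can differ.
outputs = [requests.post(URL, json=payload).json()["response"] for _ in range(3)]
print(outputs[0] == outputs[1], outputs[1] == outputs[2])
```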
Output of execution no. 1 (output can vary):
Output of execution no. 2 to execution no. n (output should be reproducible):
Observations:
OS: Linux, macOS, Windows, Docker, WSL2
GPU: Nvidia
CPU: AMD
Ollama version: 0.1.46
@mitar commented on GitHub (Jul 17, 2024):
Your version does include ead259d877, so I am not sure why.
@sayap commented on GitHub (Jul 18, 2024):
Can try to apply this patch:
The cache_prompt flag was set to true by commit a64570dca. From https://github.com/ggerganov/llama.cpp/tree/master/examples/server#api-endpoints, it says:
Once I have applied this patch, I can get the exact same output when sending the same prompt with the same seed and the same temperature, regardless of kernel restart. For example:
1st output:
2nd output:
I guess this flag should be made configurable?
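For context, cache_prompt is a per-request field of the llama.cpp server's /completion endpoint, so when talking to a llama.cpp server directly (not through Ollama) it can be turned off without patching anything. A minimal sketch, assuming a llama.cpp server listening on port 8080 and the requests package:

```python
import requests  # assumed to be installed

# Sketch against a llama.cpp server (not the Ollama API): cache_prompt can be
# set per request on /completion to avoid reusing the cached prompt prefix.
resp = requests.post(
    "http://localhost:8080/completion",
    json={
        "prompt": "Why is the sky blue?",
        "temperature": 0,
        "seed": 123,
        "cache_prompt": False,
    },
)
print(resp.json()["content"])
```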
@d-kleine commented on GitHub (Jul 18, 2024):
To be honest, I don't know if that fixes the issue. For the output, you have used a different model and a different prompt, and not validated it across different operating systems.
The KV cache is actually a helpful feature, but it might be initialized differently across operating systems. Therefore, disabling it might fix this issue, but it does not resolve the underlying issue with KV cache initialization.
https://github.com/ggerganov/llama.cpp/issues/4902
But you actually gave me an idea to bypass the KV cache just by setting num_keep=0 (this would not disable it, but at least no tokens would be stored in the cache then). I don't know how to install Ollama with your changes, neither on Ubuntu nor on Windows; I will test it once a new Ollama version with it implemented is released. Thanks for the PR anyway!
BTW I have also opened a PR on llama.cpp for making outputs 100% deterministic:
https://github.com/ggerganov/llama.cpp/discussions/8265
When using temperature=0, a small coefficient might be used to prevent division by zero. In some cases, this might slightly change the generated output, depending on the model used. Therefore, it would be better to turn off beam search and multinomial sampling for deterministic sampling.
Setting a seed only makes sense when using non-deterministic sampling, such as top-k or top-p sampling, to ensure reproducibility. The example code here for Ollama doesn't fully make sense because you would not need to set a seed, as with temperature=0 the generated output would be deterministic anyway. But when setting a temperature > 0.0, you would need to set a seed as well to make the output reproducible.
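To make the last point concrete, here is a hedged sketch with a placeholder model and prompt: with a temperature above zero the sampling is stochastic, and the seed is what makes repeated requests reproducible.

```python
import requests  # assumed to be installed

URL = "http://localhost:11434/api/generate"

def generate(seed: int) -> str:
    # temperature > 0 enables stochastic sampling; the seed pins it down.
    body = {
        "model": "llama3",
        "prompt": "Name three colors.",
        "stream": False,
        "options": {"temperature": 0.7, "seed": seed},
    }
    return requests.post(URL, json=body).json()["response"]

print(generate(42) == generate(42))  # same seed: expected True (modulo the caching issue in this thread)
print(generate(42) == generate(43))  # different seed: typically False
```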
@psambit9791 commented on GitHub (Jan 2, 2025):
Is there any update on this issue?
I have also been encountering this issue.
@sisp commented on GitHub (Feb 26, 2025):
Same here with Microsoft's Phi-3 Mini model.
@lemassykoi commented on GitHub (Mar 16, 2025):
Same here with Ollama 0.6.1
I wrote a script to test models: https://gist.github.com/lemassykoi/e1423068d1d976961953d86609877fd5
Seed and temperature are fixed values.
For each model in the list, the script restarts the Ollama service, then sends the same query twice to the model.
If the model output is strictly identical, the test passes.
If the model output differs from pass 1 to pass 2, the test fails.
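The gist itself is not reproduced here; its pass/fail logic amounts to roughly the following sketch (the service restart is omitted, and the model list and prompt are placeholders):

```python
import requests  # assumed to be installed

URL = "http://localhost:11434/api/generate"
MODELS = ["llama3.2", "mistral"]            # placeholder model list
OPTIONS = {"seed": 0, "temperature": 0}     # fixed values, as in the gist

def ask(model: str) -> str:
    body = {"model": model, "prompt": "Describe water in one sentence.",
            "stream": False, "options": OPTIONS}
    return requests.post(URL, json=body).json()["response"]

for model in MODELS:
    # The gist restarts the Ollama service here before testing each model.
    first, second = ask(model), ask(model)
    print(model, "OK" if first == second else "FAILED")
```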
@kevin-pw commented on GitHub (Mar 26, 2025):
I can confirm inconsistent outputs with Ollama v0.6.3-rc0 for several models I tested: llama3.2:latest, llama3.2-vision:11b, gemma3:12b.
I noticed that llama3.2:latest produces inconsistent results for identical inputs not only on the /generate endpoint but also on the /embed endpoint. That means the model produces different probability distributions for the same inputs. That also means that adjusting sampler options (like temperature, seed, top_p, top_k, etc.) cannot fix this problem.
The inconsistent outputs present a serious issue because they degrade the quality of any downstream application. In my tests, the cosine similarity between the different embeddings for the same inputs was as low as 99.4%, but I suspect that similarity could drop even lower. Especially for RAG applications, where similarities for large datasets often fall within a small range, this inconsistency is beyond acceptable.
I am not including any code here because the inconsistencies are somewhat difficult to reproduce. I have found the inconsistencies to occur on the same machine when I run Ollama on CPU vs GPU, and when running Ollama using the installed version vs a docker container vs compiling the development version from source. Sometimes these changes cause the inconsistencies, and sometimes they do not.
@rick-github Could you take another look at this issue? I believe the inconsistencies are serious enough to warrant attention, but I know you have a long list of priorities. Shout-out to the Ollama team for your amazing work!
@rick-github commented on GitHub (Mar 26, 2025):
A quick first pass at generating embeddings with llama3.2:latest failed to show any inconsistencies. Can you give me an idea of the type of input, chunk length, and context length you are using?
@flexorx commented on GitHub (Mar 26, 2025):
@rick-github please consider running this, for instance:
https://github.com/ollama/ollama/issues/5321#issuecomment-2727539874
https://gist.github.com/lemassykoi/e1423068d1d976961953d86609877fd5
In our case, the issue is present with pretty much any prompt, with temperature, top_k, top_p, and seed all zeroed out and repetition_penalty set to 1.0 (all non-determinism OFF), on various quantizations of Mistral 24B, Mistral 24B-3.1, and Gemma 3 of various sizes, with various context window sizes like 2k, 4k, 8k, on an Nvidia GPU, both with and without parallelism set for Ollama (always on the same OS and the same computer).
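For reference, the kind of request being described, with every sampling knob pinned down, looks roughly like this against the Ollama API. This is a sketch with placeholder model, prompt, and context size, not the poster's exact configuration; note that Ollama's option is named repeat_penalty.

```python
import requests  # assumed to be installed

body = {
    "model": "mistral-small:24b",            # placeholder model
    "prompt": "Summarize the water cycle.",  # placeholder prompt
    "stream": False,
    "options": {
        "temperature": 0,       # greedy decoding
        "top_k": 0,
        "top_p": 0,
        "seed": 0,
        "repeat_penalty": 1.0,  # repetition penalty disabled
        "num_ctx": 4096,        # placeholder context window
    },
}
resp = requests.post("http://localhost:11434/api/generate", json=body)
print(resp.json()["response"])
```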
@lemassykoi commented on GitHub (Mar 27, 2025):
New version with embed testing: https://gist.github.com/lemassykoi/5a6c0d655b5923e9588eef68d12fcbd2
It will test all the models available in your local Ollama, except some with special words in the model name like code or embed (which can't chat or generate).
Some models are not capable of embedding; they will appear as invalid models without raising an exception.
Ollama 0.6.2
@kevin-pw commented on GitHub (Mar 27, 2025):
@rick-github Thank you for looking into this issue!
I am able to reproduce the inconsistent embedding results by running Ollama compiled from source and setting CUDA_VISIBLE_DEVICES either to 0 or to -1.
First, I compile from source as usual. I am using v0.6.3-rc0, which is commit e5d84fb:
cmake -B build
cmake --build build
I then explicitly use the GPU by setting:
export CUDA_VISIBLE_DEVICES=0
Then, run Ollama:
go run . serve
In a separate terminal, I issue a curl command to Ollama. The response is:
{"model":"llama3.2:latest","embeddings":[[0.00094434456,0.013325062,-0.026114173,...,-0.03395605,0.0069057234,-0.010010444]],"total_duration":1517785083,"load_duration":1414543924,"prompt_eval_count":16}
I then stop Ollama with CTRL + C and explicitly use the CPU by setting:
export CUDA_VISIBLE_DEVICES=-1
Then, run Ollama:
go run . serve
Using the same curl command as above, the response is:
{"model":"llama3.2:latest","embeddings":[[0.0010927601,0.014933009,-0.024558352,...,-0.03514633,0.0063932026,-0.009075297]],"total_duration":225640647,"load_duration":1596095,"prompt_eval_count":16}
As you can see, using the GPU vs. the CPU to compute identical curl requests results in different embeddings. The cosine similarity between the two embeddings is 99.86% in this example, but I have seen lower similarities. Ideally, the similarity should be 100%. I haven't changed any parameters other than using the GPU vs. the CPU; my curl requests use the default context length, chunk size, etc.
@lemassykoi I tried your script but was unable to reproduce the inconsistent embedding results. Your script restarts Ollama without changing any other settings (like switching from GPU to CPU), is that right? A restart alone did not produce inconsistencies for me. Perhaps this depends on the machine Ollama is running on. I use Ubuntu Linux 24.10.
@lemassykoi commented on GitHub (Mar 27, 2025):
Yes, I don't switch between CPU and GPU
I'm using Debian 12
@sisp commented on GitHub (Mar 27, 2025):
@kevin-pw Consistency across CPU and GPU cannot be guaranteed: https://pytorch.org/docs/stable/notes/randomness.html
@kevin-pw commented on GitHub (Mar 28, 2025):
Summary
It looks like three different issues might cause different embeddings, logits, and text generation for the same inputs: 1) generating the output on different operating systems, 2) generating the outputs on CPU vs GPU, and 3) generating the outputs using the KV cache.
After some digging, issues 1) and 2) do not appear to have an easy fix. These issues could be problematic for downstream applications like RAG, but it may be possible to mitigate those issues by carefully working around them. Issue 3) can be addressed by avoiding the use of the KV cache.
@rick-github I have two suggestions:
Apologies for the long rant – this was a deep dive.
1) Generating the output on different operating systems
Observing different results on different operating systems is consistent with the following issues posted on the llama.cpp repo:
https://github.com/ggml-org/llama.cpp/issues/2582
https://github.com/ggml-org/llama.cpp/discussions/2100#discussioncomment-6353790
Unfortunately, the two issues above were never fully resolved. The second issue suggests building a portable binary by statically linking the CUDA libraries, but I haven’t tested if that approach actually achieves consistent results between operating systems.
In my tests, I ran Ollama compiled from source directly on Ubuntu Linux 24.10, and I ran Ollama within a docker container that uses Ubuntu Linux 20.04. While running both of those Ollama instances on the same machine,
gemma3:12b produced different results for the same text + image input.
To reproduce this issue:
Compile Ollama v0.6.3-rc0 from source as described in the docs.
Run Ollama with go run . serve
Run the Python code below **
On my Ubuntu Linux 24.10 machine, the response is:
{"most_likely_text": "nbg2m", "less_likely_text": "nbg2m"}This response contains the correct letters and numbers shown in the image.To receive a different response, stop Ollama with CTRL + C and:
Build the docker image with docker build -t ollama . as described in the docs.
Run Ollama in docker with docker run --gpus=all -v ollama:/root/.ollama -p 127.0.0.1:11434:11434 --name ollama ollama
Run the Python code below **
On my machine, the response is:
{"most_likely_text": "nby2m", "less_likely_text": "nbyzm"}This response DOES NOT contain the correct letters and numbers shown in the image.2) Generating the outputs on CPU vs GPU
As described in the PyTorch docs, consistency across CPU and GPU cannot be guaranteed. In fact, a large number of CUDA algorithms are non-deterministic. Some algorithms have deterministic but slower equivalents, but several other algorithms cannot behave deterministically at all (Thank you, @sisp !). So it may not ever be possible to produce consistent outputs with Ollama across different hardware.
To reproduce this issue:
Run Ollama within the docker image built in section 1) and using the additional flags
-e CUDA_VISIBLE_DEVICES=0 (for GPU) or -e CUDA_VISIBLE_DEVICES=-1 (for CPU).
Run the Python code below **
On GPU, the response was:
{"most_likely_text": "nby2m", "less_likely_text": "nbyzm"}On CPU, the response was:
{"most_likely_text": "nby2m", "less_likely_text": "nbytm"}(withtinstead ofz, but both responses are incorrect.)3) Generating the outputs using the KV cache
The KV cache temporarily stores part of a prompt and its response so that identical prompts do not have to be regenerated by the model when submitting the same prompt twice. However, using the stored results can cause the
/generate endpoint to produce different results when processing the same inputs twice in quick succession.
The /embed endpoint is unaffected by this issue because use of the KV cache is explicitly set to false: 01aa788722/runner/llamarunner/runner.go (L703)
Relevant known issues related to the KV cache:
https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
https://github.com/ggml-org/llama.cpp/issues/3014
To reproduce this issue:
Run Ollama within the docker image built in section 1) and using the additional flags
-e CUDA_VISIBLE_DEVICES=-1 (for CPU), and then run the Python code below ** twice.
First response:
{"most_likely_text": "nby2m", "less_likely_text": "nbytm"}
Second response:
{"most_likely_text": "n2m", "less_likely_text": "bg2m"}
As you can see, the first and second responses are completely different. Both responses are incorrect when compared to the letters and numbers contained in the image.
Why different results for identical inputs present a significant problem
Downstream applications like retrieval augmented generation (RAG), classification, annotation, or semantic search depend on consistent embedding, logit, and text generation. Those applications usually compare similarity between embeddings, so any uncertainty in embeddings reduces the quality of results that those applications can produce. For large datasets, similarities between embeddings often fall within a relatively small range, so even small differences in embeddings can make those applications unusable.
Possible workarounds to generate consistent outputs for the same inputs (to be confirmed):
Disable use of the KV cache on the /generate endpoint for the llama runner (01aa788722/runner/llamarunner/runner.go (L611)) or for the ollama runner (01aa788722/runner/ollamarunner/runner.go (L600)).
** Code used to investigate all three causes of different results for the same inputs:
image “lettersandnumbers.jpg”:
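The collapsed script is not reproduced in this mirror. Purely as a hypothetical sketch of the kind of request the responses above imply (a vision model, the attached image, and JSON-formatted output; the model name, prompt, and options are placeholders, not the original code):

```python
import base64
import requests  # assumed to be installed

# Hypothetical sketch only; not the collapsed script from the original comment.
with open("lettersandnumbers.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

body = {
    "model": "gemma3:12b",
    "prompt": ("Read the characters in the image and answer as JSON with the "
               "keys most_likely_text and less_likely_text."),
    "images": [image_b64],
    "format": "json",
    "stream": False,
    "options": {"temperature": 0, "seed": 0},
}
resp = requests.post("http://localhost:11434/api/generate", json=body)
print(resp.json()["response"])
```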
@flexorx commented on GitHub (Mar 30, 2025):
I can't get it, @kevin-pw. In our case we are running this on the same OS (RedHat), on the same computer, in the same environment, without any Docker, without parallelism, and purely on GPU. We do NOT do anything explicit or special regarding the KV cache; we just set temperature, top_p, top_k and the rest of this stuff to 0 and repetition_penalty to 1.0 to remove ANY non-determinism and ensure results are fully reproducible. What is the cause in this case? The whole discussion has taken on a "there's nothing we can do" sentiment, but this was obviously not an issue previously under our conditions and only appeared sometime in 2024 as a kind of bug.
@kevin-pw commented on GitHub (Mar 30, 2025):
@flexorx Does the following correctly describe your issue when you submit multiple identical input queries?
If that is the case, the issue is caused by the KV cache. In the current version
0.6.3 of Ollama, no input parameter is available to disable the KV cache on the /generate endpoint. To eliminate the issue, you would need to implement workaround 3) from my comment above by modifying the two lines in the source code.
@flexorx commented on GitHub (Mar 30, 2025):
@kevin-pw Yes, indeed, the result of the first query is often different from the second and subsequent queries for the same input. Additionally, if some other input is sent before the first submission of a specific query, then the first result for that query differs from the first result we get for the same query without that other input preceding it.
So, essentially, we can say two things:
The result of the first submission of a particular query always differs from the second and subsequent submissions of that same query, provided that everything preceding that first submission stays the same and that no other queries are interleaved between the first, second, and subsequent submissions of that query.
In general, if we perform a sequence of queries A, B, C, D, the results vary for any permutation of the order, such as B, C, A, D or D, A, C, B. The results are therefore not invariant under permutation of the queries; they are path-dependent.
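A sketch of the permutation check being described: run the same set of prompts in two different orders and compare each prompt's answer across the two runs (model, prompts, and options are placeholders):

```python
import requests  # assumed to be installed

URL = "http://localhost:11434/api/generate"
OPTIONS = {"temperature": 0, "seed": 0, "top_k": 0, "top_p": 0, "repeat_penalty": 1.0}
PROMPTS = {"A": "Define entropy.", "B": "Define enthalpy.",
           "C": "Define free energy.", "D": "Define heat capacity."}

def run(order: str) -> dict:
    results = {}
    for name in order:
        body = {"model": "mistral-small:24b", "prompt": PROMPTS[name],
                "stream": False, "options": OPTIONS}
        results[name] = requests.post(URL, json=body).json()["response"]
    return results

first, second = run("ABCD"), run("BCAD")
for name in PROMPTS:
    # Path dependence, as described above: the answer to the same prompt can
    # depend on which prompts were served before it.
    print(name, first[name] == second[name])
```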
@JakeBeaver commented on GitHub (Mar 31, 2025):
I found some simple repro steps for the /api/chat endpoint. Repeating this gives the same result in a loop. After the first response, message.content loses the "Hello!" and eval_count changes from 18 to 16 for all subsequent responses.
At least it's enough to unload with an API request, so I don't have to rely on shell scripts rebooting all of Ollama, but still, it seems wasteful to keep keep_alive at 0 and force the model to unload after every request.
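A sketch of the kind of loop being described, checking message.content and eval_count across repeated, identical /api/chat requests (model and options are placeholders):

```python
import requests  # assumed to be installed

URL = "http://localhost:11434/api/chat"
body = {
    "model": "llama3.2",  # placeholder model
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": False,
    "options": {"temperature": 0, "seed": 0},
}

for i in range(3):
    reply = requests.post(URL, json=body).json()
    # Per the report above, message.content and eval_count change after the
    # first response (e.g. eval_count drops from 18 to 16).
    print(i, reply["eval_count"], repr(reply["message"]["content"]))
```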
@wyli commented on GitHub (Mar 31, 2025):
Disabling the KV cache as mentioned by @kevin-pw works for me. I put a possible implementation here to make it configurable: https://github.com/ollama/ollama/pull/10064
@d-kleine commented on GitHub (Apr 4, 2025):
@sayap already identified the KV cache as the root of the problem with generating consistent reproducible outputs almost a year ago. He also submitted a fix which worked at least when I tested it back then (#5760), but this PR has not been merged.
I have also tried to make the PRNG for the cache initialization consistent with a seed, but never got it really working. Disabling the KV cache forces the LLM to fully recompute the attention matrices, therefore using more memory than having the KV cache enabled.
@wyli commented on GitHub (Apr 4, 2025):
Not sure why that PR was not considered as well... in general these changes don't alter the default and increase flexibility.
(I think the analysis here has already demonstrated the numerical differences: https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535)
@Jonas-Wessner commented on GitHub (Feb 25, 2026):
I confirm the bug on Ubuntu 24.04.3, running on L40S GPUs.
When I execute the script for the first time, I get a random output for iteration 0 of the loop. For subsequent iterations, I get a different, but consistent output.
If I run the script again, the bug is gone.
If I change something about the prompt (causing some cache reload I suppose), the bug can be reproduced again.
I hope this can be fixed soon, since otherwise it is hard to guarantee reproducible experiment results.