deepseekocr

commit 92981ae3f2 (parent 8ed1adf3db), committed by Michael Yang
@@ -206,6 +206,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
 		conv = &commandrModel{}
 	case "GptOssForCausalLM":
 		conv = &gptossModel{}
+	case "DeepseekOCRForCausalLM":
+		conv = &deepseekocr{}
 	default:
 		return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
 	}
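The converter dispatches on the architectures list from the checkpoint's config.json, so this hunk routes checkpoints declaring "DeepseekOCRForCausalLM" to the new deepseekocr converter added below. A minimal sketch of how those parameters get decoded, not part of the commit; the struct mirrors the json tags in convert_deepseekocr.go and the sample values are made up:

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// Illustrative config.json excerpt; only the architecture string is
	// taken from the commit, the numbers are invented.
	sample := []byte(`{
		"architectures": ["DeepseekOCRForCausalLM"],
		"language_config": {"num_hidden_layers": 12, "hidden_size": 1280}
	}`)

	var p struct {
		Architectures  []string `json:"architectures"`
		LanguageConfig struct {
			HiddenLayers uint32 `json:"num_hidden_layers"`
			HiddenSize   uint32 `json:"hidden_size"`
		} `json:"language_config"`
	}
	if err := json.Unmarshal(sample, &p); err != nil {
		panic(err)
	}
	fmt.Println(p.Architectures[0], p.LanguageConfig.HiddenLayers) // DeepseekOCRForCausalLM 12
}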
convert/convert_deepseekocr.go (new file, 136 lines)
@@ -0,0 +1,136 @@
package convert

import (
	"fmt"

	"github.com/ollama/ollama/fs/ggml"
)

type deepseekocr struct {
	ModelParameters
	LanguageConfig struct {
		MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
		HiddenSize            uint32 `json:"hidden_size"`
		HiddenLayers          uint32 `json:"num_hidden_layers"`
		IntermediateSize      uint32 `json:"intermediate_size"`
		NumAttentionHeads     uint32 `json:"num_attention_heads"`
		NumKeyValueHeads      uint32 `json:"num_key_value_heads"`
		NumRoutedExperts      uint32 `json:"n_routed_experts"`
		NumSharedExperts      uint32 `json:"n_shared_experts"`
		NumExpertsPerToken    uint32 `json:"num_experts_per_tok"`
		FirstKDenseReplace    uint32 `json:"first_k_dense_replace"`
	} `json:"language_config"`

	VisionConfig struct {
		ImageSize uint32 `json:"image_size"`
		Width     struct {
			Vision struct {
				Heads     uint32 `json:"heads"`
				ImageSize uint32 `json:"image_size"`
				Layers    uint32 `json:"layers"`
				PatchSize uint32 `json:"patch_size"`
				Width     uint32 `json:"width"`
			} `json:"clip-l-14-224"`
			Sam struct {
				GlobalAttentionIndexes []int32 `json:"global_attn_indexes"`
				Heads                  uint32  `json:"heads"`
				Layers                 uint32  `json:"layers"`
				Width                  uint32  `json:"width"`
			} `json:"sam_vit_b"`
		}
	} `json:"vision_config"`
}

func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
	kv := m.ModelParameters.KV(t)
	kv["general.architecture"] = "deepseekocr"
	kv["block_count"] = m.LanguageConfig.HiddenLayers
	kv["context_length"] = m.LanguageConfig.MaxPositionEmbeddings
	kv["embedding_length"] = m.LanguageConfig.HiddenSize
	kv["feed_forward_length"] = m.LanguageConfig.IntermediateSize
	kv["attention.head_count"] = m.LanguageConfig.NumAttentionHeads
	kv["attention.head_count_kv"] = m.LanguageConfig.NumKeyValueHeads
	kv["expert_count"] = m.LanguageConfig.NumRoutedExperts
	kv["expert_used_count"] = m.LanguageConfig.NumExpertsPerToken
	kv["leading_dense_block_count"] = m.LanguageConfig.FirstKDenseReplace

	kv["vision.block_count"] = m.VisionConfig.Width.Vision.Layers
	kv["vision.embedding_length"] = m.VisionConfig.Width.Vision.Width
	kv["vision.head_count"] = m.VisionConfig.Width.Vision.Heads
	kv["vision.image_size"] = m.VisionConfig.Width.Vision.ImageSize
	kv["vision.patch_size"] = m.VisionConfig.Width.Vision.PatchSize

	kv["sam.block_count"] = m.VisionConfig.Width.Sam.Layers
	kv["sam.embedding_length"] = m.VisionConfig.Width.Sam.Width
	kv["sam.head_count"] = m.VisionConfig.Width.Sam.Heads
	kv["sam.global_attention_indexes"] = m.VisionConfig.Width.Sam.GlobalAttentionIndexes
	return kv
}

func (m *deepseekocr) Tensors(s []Tensor) (out []*ggml.Tensor) {
	merges := make([]merge, m.LanguageConfig.HiddenLayers*3)
	for i := range m.LanguageConfig.HiddenLayers {
		merges[i*3+0] = merge{
			fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
			fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
		}
		merges[i*3+1] = merge{
			fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
			fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
		}
		merges[i*3+2] = merge{
			fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
			fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
		}
	}

	out, s = mergeTensors(s, merges...)
	for _, t := range s {
		out = append(out, &ggml.Tensor{
			Name:     t.Name(),
			Kind:     t.Kind(),
			Shape:    t.Shape(),
			WriterTo: t,
		})
	}
	return out
}

func (m *deepseekocr) Replacements() []string {
	return []string{
		"model.embed_tokens", "token_embd",
		"model.layers", "blk",
		"input_layernorm", "attn_norm",
		"self_attn.q_proj", "attn_q",
		"self_attn.k_proj", "attn_k",
		"self_attn.v_proj", "attn_v",
		"self_attn.o_proj", "attn_output",
		"post_attention_layernorm", "ffn_norm",
		"mlp.gate_proj", "ffn_gate",
		"mlp.up_proj", "ffn_up",
		"mlp.down_proj", "ffn_down",
		"mlp.gate", "ffn_gate_inp",
		"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
		"mlp.shared_experts.up_proj", "ffn_up_shexp",
		"mlp.shared_experts.down_proj", "ffn_down_shexp",
		"model.norm", "output_norm",
		"lm_head", "output",

		"model.vision_model", "v",
		"embeddings.patch_embedding", "patch_embd",
		"embeddings.class_embedding", "class_embd",
		"embeddings.position_embedding", "position_embd",
		"transformer.layers", "blk",

		"model.projector", "mm",
		"model.image_newline", "mm.image_newline",
		//nolint:misspell // this misspelling is upstream. fixing it breaks the model
		"model.view_seperator", "mm.view_seperator",

		"model.sam_model.patch_embed.proj", "s.patch_embd",
		"model.sam_model.pos_embed", "s.position_embd",
		"model.sam_model.blocks", "s.blk",
		"model.sam_model.neck", "s.neck",
		"model.sam_model.net_", "s.net_",
	}
}
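The Tensors method above relies on mergeTensors to collapse the per-expert mlp.experts.* weights into a single stacked ffn_*_exps tensor per layer. A toy sketch of that stacking idea, not part of the commit; plain slices stand in for the converter's lazily-read tensors:

package main

import "fmt"

// stackExperts concatenates per-expert weights along a new leading expert
// dimension, which is what the merge patterns above request.
func stackExperts(experts [][]float32) []float32 {
	out := make([]float32, 0, len(experts)*len(experts[0]))
	for _, w := range experts {
		out = append(out, w...)
	}
	return out
}

func main() {
	gate0 := []float32{1, 2} // stand-in for blk.0.mlp.experts.0.gate_proj.weight
	gate1 := []float32{3, 4} // stand-in for blk.0.mlp.experts.1.gate_proj.weight
	fmt.Println(stackExperts([][]float32{gate0, gate1})) // -> blk.0.ffn_gate_exps.weight: [1 2 3 4]
}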
@@ -44,7 +44,10 @@ func (t tensorBase) Kind() uint32 {
 		t.name == "v.positional_embedding_vlm" ||
 		t.name == "v.tile_position_embd.weight" ||
 		t.name == "v.pre_tile_position_embd.weight" ||
-		t.name == "v.post_tile_position_embd.weight" {
+		t.name == "v.post_tile_position_embd.weight" ||
+		t.name == "s.position_embd" ||
+		strings.HasSuffix(t.name, "rel_pos_h") ||
+		strings.HasSuffix(t.name, "rel_pos_w") {
 		// these tensors are always F32
 		return tensorKindFP32
 	}
@@ -96,7 +96,10 @@ type safetensor struct {

 func (st safetensor) Kind() uint32 {
 	kind := st.tensorBase.Kind()
-	if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 {
+	if st.dtype == "BF16" &&
+		!strings.HasPrefix(st.name, "v.") &&
+		!strings.HasPrefix(st.name, "s.") &&
+		kind != tensorKindFP32 {
 		kind = tensorKindBF16
 	}
@@ -249,6 +249,7 @@ func (kv KV) OllamaEngineRequired() bool {
 		"qwen25vl",
 		"qwen3", "qwen3moe",
 		"qwen3vl", "qwen3vlmoe",
+		"deepseekocr",
 	}, kv.Architecture())
 }
@@ -173,6 +173,7 @@ type Tensor interface {
 	Cos(ctx Context) Tensor
 	Tanh(ctx Context) Tensor
 	GELU(ctx Context, up ...Tensor) Tensor
+	QuickGELU(ctx Context, up ...Tensor) Tensor
 	SILU(ctx Context, up ...Tensor) Tensor
 	RELU(ctx Context, up ...Tensor) Tensor
 	Sigmoid(ctx Context) Tensor
@@ -207,6 +208,8 @@ type Tensor interface {
 	Stddev(ctx Context) Tensor
 	Sqr(ctx Context) Tensor
 	Sqrt(ctx Context) Tensor
+
+	Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
 }

 // ScaledDotProductAttention implements a fused attention
@@ -372,3 +375,10 @@ const (
 	DTypeI32
 	DTypeMXFP4
 )
+
+type SamplingMode int
+
+const (
+	SamplingModeNearest SamplingMode = iota
+	SamplingModeBilinear
+)
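The new SamplingMode enum backs the Interpolate method that the vision code below uses to rescale position embeddings at load time. For intuition, here is a self-contained 1-D analogue of the two modes, not part of the commit; the exact coordinate convention of ggml's kernel is an assumption here:

package main

import "fmt"

// resample stretches a 1-D signal to n samples, either by nearest
// neighbor or by linearly blending the two bracketing samples.
func resample(src []float32, n int, bilinear bool) []float32 {
	out := make([]float32, n)
	scale := float64(len(src)) / float64(n)
	for i := range out {
		pos := (float64(i)+0.5)*scale - 0.5
		if pos < 0 {
			pos = 0
		}
		j := int(pos)
		if !bilinear || j+1 >= len(src) {
			out[i] = src[j] // nearest neighbor, or clamped at the edge
			continue
		}
		frac := float32(pos - float64(j))
		out[i] = src[j]*(1-frac) + src[j+1]*frac
	}
	return out
}

func main() {
	sig := []float32{0, 3, 6}
	fmt.Println(resample(sig, 6, false)) // SamplingModeNearest analogue: [0 0 0 3 3 6]
	fmt.Println(resample(sig, 6, true))  // SamplingModeBilinear analogue: [0 0.75 2.25 3.75 5.25 6]
}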
@@ -314,7 +314,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
 			"altup_proj", "altup_unembd_proj",
 			"per_layer_token_embd", "per_layer_model_proj", "per_layer_proj_norm"):
 			createTensor(tensor{source: t}, output.bts, blocks)
-		case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm."):
+		case strings.HasPrefix(t.Name, "v.") || strings.HasPrefix(t.Name, "mm.") || strings.HasPrefix(t.Name, "s."):
 			// TODO: assign vision tensors to the gpu if possible
 			createTensor(tensor{source: t}, output.bts, blocks)
 		case contains(t.Name, "rope_freqs", "rope_factors_long", "rope_factors_short"):
@@ -1567,6 +1567,16 @@ func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
 	}
 }

+func (t *Tensor) QuickGELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+	var tt *C.struct_ggml_tensor
+	if len(t2) > 0 {
+		tt = C.ggml_geglu_quick_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t)
+	} else {
+		tt = C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t)
+	}
+	return &Tensor{b: t.b, t: tt}
+}
+
 func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
 	if len(t2) > 0 {
 		return &Tensor{
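QuickGELU is the sigmoid-based GELU approximation used by CLIP-style vision towers, which is why the DeepSeek-OCR vision MLP below calls it. A runnable sketch of the scalar function, not part of the commit; the 1.702 constant is the standard published approximation, and matching ggml's kernel exactly is assumed:

package main

import (
	"fmt"
	"math"
)

// quickGELU computes x * sigmoid(1.702 * x).
func quickGELU(x float64) float64 {
	return x / (1 + math.Exp(-1.702*x))
}

func main() {
	for _, x := range []float64{-1, 0, 1} {
		fmt.Printf("quickGELU(%v) = %.4f\n", x, quickGELU(x))
	}
}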
@@ -1724,6 +1734,23 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
 	}
 }

+func (t *Tensor) Interpolate(ctx ml.Context, dims [4]int, samplingMode ml.SamplingMode) ml.Tensor {
+	var mode C.uint32_t
+	switch samplingMode {
+	case ml.SamplingModeNearest:
+		mode = C.GGML_SCALE_MODE_NEAREST
+	case ml.SamplingModeBilinear:
+		mode = C.GGML_SCALE_MODE_BILINEAR
+	default:
+		panic("unsupported interpolate mode")
+	}
+
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_interpolate(ctx.(*Context).ctx, t.t, C.int64_t(dims[0]), C.int64_t(dims[1]), C.int64_t(dims[2]), C.int64_t(dims[3]), mode),
+	}
+}
+
 // Slice returns a view of the tensor sliced along dim from low to high in step steps.
 // Slice panics if the dimension is invalid or the slice parameters are out of range.
 // If dim=0 and step>1, the tensor is a copy rather than a view to ensure proper shape.
@@ -25,12 +25,15 @@ const (

 // Composite returns an image with the alpha channel removed by drawing over a white background.
 func Composite(img image.Image) image.Image {
-	dst := image.NewRGBA(img.Bounds())
-
 	white := color.RGBA{255, 255, 255, 255}
-	draw.Draw(dst, dst.Bounds(), &image.Uniform{white}, image.Point{}, draw.Src)
-	draw.Draw(dst, dst.Bounds(), img, img.Bounds().Min, draw.Over)
-	return dst
+	return CompositeColor(img, white)
 }

+// CompositeColor returns an image with the alpha channel removed by drawing over the given background color.
+func CompositeColor(img image.Image, color color.Color) image.Image {
+	dst := image.NewRGBA(img.Bounds())
+	draw.Draw(dst, dst.Bounds(), &image.Uniform{color}, image.Point{}, draw.Src)
+	draw.Draw(dst, dst.Bounds(), img, img.Bounds().Min, draw.Over)
+	return dst
+}
+
@@ -55,6 +58,31 @@ func Resize(img image.Image, newSize image.Point, method int) image.Image {
 	return dst
 }

+// Pad returns an image which has been resized to fit within a new size, preserving aspect ratio, and padded with a color.
+func Pad(img image.Image, newSize image.Point, color color.Color, kernel draw.Interpolator) image.Image {
+	dst := image.NewRGBA(image.Rect(0, 0, newSize.X, newSize.Y))
+	draw.Draw(dst, dst.Bounds(), &image.Uniform{color}, image.Point{}, draw.Src)
+
+	var minPoint, maxPoint image.Point
+	if img.Bounds().Dx() > img.Bounds().Dy() {
+		// landscape
+		height := newSize.X * img.Bounds().Dy() / img.Bounds().Dx()
+		minPoint = image.Point{0, (newSize.Y - height) / 2}
+		maxPoint = image.Point{newSize.X, height + minPoint.Y}
+	} else {
+		// portrait
+		width := newSize.Y * img.Bounds().Dx() / img.Bounds().Dy()
+		minPoint = image.Point{(newSize.X - width) / 2, 0}
+		maxPoint = image.Point{minPoint.X + width, newSize.Y}
+	}
+
+	kernel.Scale(dst, image.Rectangle{
+		Min: minPoint,
+		Max: maxPoint,
+	}, img, img.Bounds(), draw.Over, nil)
+	return dst
+}
+
 // Normalize returns a slice of float32 containing each of the r, g, b values for an image normalized around a value.
 func Normalize(img image.Image, mean, std [3]float32, rescale bool, channelFirst bool) []float32 {
 	var pixelVals []float32
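A brief usage sketch of the new Pad helper, matching how the DeepSeek-OCR processor letterboxes its global view below; not part of the commit, and the 200x100 source image is made up:

package main

import (
	"image"
	"image/color"

	"golang.org/x/image/draw"

	"github.com/ollama/ollama/model/imageproc"
)

func main() {
	// A solid 200x100 canvas stands in for a decoded photo.
	src := image.NewRGBA(image.Rect(0, 0, 200, 100))
	// Letterbox into a 1024x1024 square on mid-gray, preserving aspect ratio.
	dst := imageproc.Pad(src, image.Point{X: 1024, Y: 1024}, color.Gray{127}, draw.BiLinear)
	_ = dst
}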
model/models/deepseekocr/imageprocessor.go (new file, 83 lines)
@@ -0,0 +1,83 @@
package deepseekocr

import (
	"bytes"
	"image"
	"image/color"
	"math"
	"slices"

	"golang.org/x/image/draw"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/imageproc"
)

type ratio struct {
	x, y int
}

func ProcessImage(ctx ml.Context, bts []byte) (ml.Tensor, ml.Tensor, []int, error) {
	img, _, err := image.Decode(bytes.NewReader(bts))
	if err != nil {
		return nil, nil, nil, err
	}

	minNum, maxNum, imageSize, baseSize := 2, 9, 640, 1024
	var targetRatios []ratio
	for n := minNum; n <= maxNum; n++ {
		for i := 1; i <= n; i++ {
			for j := 1; j <= n; j++ {
				if i*j <= maxNum && i*j >= minNum && !slices.Contains(targetRatios, ratio{i, j}) {
					targetRatios = append(targetRatios, ratio{i, j})
				}
			}
		}
	}

	targetRatio := findBestAspectRatio(targetRatios, img.Bounds().Dx(), img.Bounds().Dy(), imageSize)
	targetWidth, targetHeight := imageSize*targetRatio.x, imageSize*targetRatio.y
	blocks := targetRatio.x * targetRatio.y

	mean := imageproc.ImageNetStandardMean
	std := imageproc.ImageNetStandardSTD

	var patches []float32
	resized := imageproc.Resize(img, image.Point{X: targetWidth, Y: targetHeight}, imageproc.ResizeBilinear)
	for i := range blocks {
		patch := image.NewRGBA(image.Rect(0, 0, imageSize, imageSize))
		draw.Draw(patch, patch.Bounds(), resized, image.Point{
			X: i % (targetWidth / imageSize) * imageSize,
			Y: i / (targetWidth / imageSize) * imageSize,
		}, draw.Over)

		patches = append(patches, imageproc.Normalize(patch, mean, std, true, true)...)
	}

	img = imageproc.CompositeColor(img, color.Gray{})
	img = imageproc.Pad(img, image.Point{X: baseSize, Y: baseSize}, color.Gray{127}, draw.BiLinear)

	return ctx.Input().FromFloats(patches, imageSize, imageSize, 3, blocks),
		ctx.Input().FromFloats(imageproc.Normalize(img, mean, std, true, true), baseSize, baseSize, 3),
		[]int{targetRatio.x, targetRatio.y},
		nil
}

func findBestAspectRatio(targetRatios []ratio, width, height, imageSize int) ratio {
	bestDiff := math.MaxFloat64
	best := ratio{1, 1}
	realRatio := float64(width) / float64(height)
	for _, target := range targetRatios {
		targetRatio := float64(target.x) / float64(target.y)
		diff := math.Abs(realRatio - targetRatio)
		if diff < bestDiff {
			bestDiff = diff
			best = target
		} else if diff == bestDiff {
			if float64(width*height) > 0.5*float64(imageSize*imageSize*best.x*best.y) {
				best = target
			}
		}
	}
	return best
}
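To make the tile selection concrete: a 1280x960 input has aspect ratio 1.33; among grids of 2 to 9 tiles, 3x2 (ratio 1.5) is the closest admissible match, since the exact 4x3 grid would need 12 tiles. The throwaway check below, not part of the commit, reproduces just that arithmetic:

package main

import (
	"fmt"
	"math"
)

func main() {
	width, height := 1280, 960
	real := float64(width) / float64(height)
	type ratio struct{ x, y int }
	best, bestDiff := ratio{1, 1}, math.MaxFloat64
	for x := 1; x <= 9; x++ {
		for y := 1; y <= 9; y++ {
			if n := x * y; n < 2 || n > 9 {
				continue // same 2..9 tile constraint as ProcessImage
			}
			if d := math.Abs(real - float64(x)/float64(y)); d < bestDiff {
				best, bestDiff = ratio{x, y}, d
			}
		}
	}
	fmt.Printf("grid %dx%d -> %d local 640px tiles\n", best.x, best.y, best.x*best.y) // grid 3x2 -> 6 local 640px tiles
}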
model/models/deepseekocr/model.go (new file, 192 lines)
@@ -0,0 +1,192 @@
package deepseekocr

import (
	"math"
	"slices"

	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

type Model struct {
	model.Base
	model.TextProcessor

	Sam    *samModel    `gguf:"s"`
	Vision *visionModel `gguf:"v"`
	Text   *textModel

	ImageNewline ml.Tensor `gguf:"mm.image_newline"`
	//nolint:misspell // this misspelling is upstream. fixing it breaks the model
	ViewSeperator ml.Tensor `gguf:"mm.view_seperator"`

	Projector *nn.Linear `gguf:"mm.layers"`
}

func (m *Model) EncodeMultimodal(ctx ml.Context, bts []byte) ([]input.Multimodal, error) {
	patches, original, crop, err := ProcessImage(ctx, bts)
	if err != nil {
		return nil, err
	}

	var outputs []ml.Tensor
	if true { // TODO: local features if sum(patches) != 0
		samOutputs := m.Sam.Forward(ctx, patches)
		visionOutputs := m.Vision.Forward(ctx, patches, samOutputs)

		samOutputs = samOutputs.Reshape(ctx, -1, samOutputs.Dim(2), samOutputs.Dim(3)).Permute(ctx, 1, 0, 2, 3)
		visionOutputs = visionOutputs.Slice(ctx, 1, 1, visionOutputs.Dim(1), 1)
		localOutputs := visionOutputs.Concat(ctx, samOutputs, 0)
		localOutputs = m.Projector.Forward(ctx, localOutputs)

		hw := int(math.Sqrt(float64(localOutputs.Dim(1))))
		localOutputs = localOutputs.Reshape(ctx, -1, hw, crop[0], crop[1])
		localOutputs = localOutputs.Permute(ctx, 0, 2, 1, 3)
		localOutputs = localOutputs.Contiguous(ctx, -1, crop[0]*hw, crop[1]*hw)
		localOutputs = localOutputs.Concat(ctx, m.ImageNewline.Repeat(ctx, 2, localOutputs.Dim(2)), 1)
		localOutputs = localOutputs.Reshape(ctx, localOutputs.Dim(0), -1)

		outputs = append(outputs, localOutputs)
	}

	samOutputs := m.Sam.Forward(ctx, original)
	visionOutputs := m.Vision.Forward(ctx, original, samOutputs)

	samOutputs = samOutputs.Reshape(ctx, -1, samOutputs.Dim(2), samOutputs.Dim(3)).Permute(ctx, 1, 0, 2, 3)
	visionOutputs = visionOutputs.Slice(ctx, 1, 1, visionOutputs.Dim(1), 1)
	globalOutputs := visionOutputs.Concat(ctx, samOutputs, 0)
	globalOutputs = m.Projector.Forward(ctx, globalOutputs)

	hw := int(math.Sqrt(float64(globalOutputs.Dim(1))))
	globalOutputs = globalOutputs.Reshape(ctx, -1, hw, hw)
	globalOutputs = globalOutputs.Concat(ctx, m.ImageNewline.Repeat(ctx, 2, globalOutputs.Dim(2)), 1)
	globalOutputs = globalOutputs.Reshape(ctx, globalOutputs.Dim(0), -1)

	outputs = append(outputs, globalOutputs, m.ViewSeperator)
	return []input.Multimodal{
		{Tensor: outputs[0].Stack(ctx, 1, outputs[1:]...)},
	}, nil
}

func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
	outputs := make([]*input.Input, 0, len(inputs))
	for i := range inputs {
		if inputs[i].Multimodal == nil {
			outputs = append(outputs, inputs[i])
			continue
		}

		t := inputs[i].Multimodal[0].Tensor
		outputs = append(outputs, &input.Input{
			Token:          128815,
			Multimodal:     inputs[i].Multimodal,
			MultimodalHash: inputs[i].MultimodalHash,
			SameBatch:      t.Dim(1) - 1,
		})

		outputs = slices.Grow(outputs, t.Dim(1)-1)
		outputs = append(outputs, slices.Repeat([]*input.Input{{Token: 128815}}, t.Dim(1)-1)...)
	}
	return outputs, nil
}

func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
	inputsEmbeds := m.Text.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
	positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))

	for _, mm := range batch.Multimodal {
		t := mm.Multimodal[0].Tensor
		ctx.Forward(t.Copy(ctx, inputsEmbeds.View(ctx, mm.Index*inputsEmbeds.Stride(1), t.Dim(0)*t.Dim(1))))
	}

	hiddenStates := inputsEmbeds
	for i, block := range m.Text.Blocks {
		if m.Cache != nil {
			m.Cache.SetLayer(i)
		}

		var outputs ml.Tensor
		if i == len(m.Text.Blocks)-1 {
			outputs = batch.Outputs
		}

		hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Text.Options)
	}

	hiddenStates = m.Text.OutputNorm.Forward(ctx, hiddenStates, m.Text.Options.eps)
	return m.Text.Output.Forward(ctx, hiddenStates), nil
}

func init() {
	model.Register("deepseekocr", func(c fs.Config) (model.Model, error) {
		textBlocks := make([]textBlock, c.Uint("block_count"))
		leadingDenseBlockCount := int(c.Uint("leading_dense_block_count", 1))
		for i := range textBlocks {
			if i >= leadingDenseBlockCount {
				textBlocks[i].FeedForward = &textMoe{}
			} else {
				textBlocks[i].FeedForward = &textMLP{}
			}
		}

		m := Model{
			TextProcessor: model.NewBytePairEncoding(
				&model.Vocabulary{
					Values: c.Strings("tokenizer.ggml.tokens"),
					Types:  c.Ints("tokenizer.ggml.token_type"),
					Merges: c.Strings("tokenizer.ggml.merges"),
					AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
					BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
					AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
					EOS: append(
						[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
						c.Ints("tokenizer.ggml.eos_token_ids")...,
					),
				},
				// Split regex into multiple parts (according to DeepSeek3's regex)
				"\\p{N}{1,3}",
				`[一-龥-ゟ゠-ヿ]+`,
				"[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
			),
			Text: &textModel{
				Blocks: textBlocks,
				Options: textOptions{
					hiddenSize:     int(c.Uint("embedding_length")),
					numHeads:       int(c.Uint("attention.head_count")),
					numKVHeads:     int(c.Uint("attention.head_count_kv")),
					numExperts:     int(c.Uint("expert_count")),
					numExpertsUsed: int(c.Uint("expert_used_count")),
					ropeBase:       c.Float("rope.freq_base", 10_000),
					ropeScale:      c.Float("rope.scaling.factor", 1.0),
					eps:            c.Float("attention.layer_norm_rms_epsilon", 1e-6),
				},
			},
			Vision: &visionModel{
				Blocks: make([]visionBlock, c.Uint("vision.block_count")),
				Options: visionOptions{
					hiddenSize: int(c.Uint("vision.embedding_length")),
					numHeads:   int(c.Uint("vision.head_count")),
					imageSize:  int(c.Uint("vision.image_size", 224)),
					patchSize:  int(c.Uint("vision.patch_size", 14)),
					eps:        c.Float("vision.attention.layer_norm_epsilon", 1e-5),
				},
			},
			Sam: &samModel{
				Blocks: make([]samBlock, c.Uint("sam.block_count")),
				Options: samOptions{
					hiddenSize:            int(c.Uint("sam.embedding_length")),
					numHeads:              int(c.Uint("sam.head_count")),
					eps:                   c.Float("sam.attention.layer_norm_epsilon", 1e-6),
					globalAttentionLayers: c.Ints("sam.global_attention_indexes"),
				},
			},
		}

		m.Cache = kvcache.NewCausalCache(m.Text.Shift)
		return &m, nil
	})
}
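PostTokenize above reserves one slot per image embedding by repeating placeholder token 128815, with SameBatch keeping the repeats in the same forward pass so the image tensor can be copied over them in Forward. A simplified sketch of the expansion, not part of the commit; plain ints stand in for input.Input:

package main

import "fmt"

// expand inserts n copies of the image placeholder token at imageAt,
// mirroring the slot-reservation PostTokenize performs.
func expand(tokens []int, imageAt, n int) []int {
	out := append([]int{}, tokens[:imageAt]...)
	for range n {
		out = append(out, 128815) // image placeholder token
	}
	return append(out, tokens[imageAt:]...)
}

func main() {
	fmt.Println(expand([]int{1, 2, 3}, 1, 4)) // [1 128815 128815 128815 128815 2 3]
}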
model/models/deepseekocr/model_sam.go (new file, 225 lines)
@@ -0,0 +1,225 @@
package deepseekocr

import (
	"math"
	"slices"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

type samModel struct {
	PatchEmbedding    *nn.Conv2D `gguf:"patch_embd"`
	PositionEmbedding ml.Tensor  `gguf:"position_embd"`

	Blocks []samBlock `gguf:"blk"`

	Neck *samNeck   `gguf:"neck"`
	Net2 *nn.Conv2D `gguf:"net_2"`
	Net3 *nn.Conv2D `gguf:"net_3"`

	Options samOptions
}

func (m *samModel) absolutePositionEmbedding(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
	source := m.PositionEmbedding.Dim(1)
	target := hiddenStates.Dim(2)
	if source != target {
		positionEmbed := m.PositionEmbedding.Permute(ctx, 2, 0, 1, 3)
		positionEmbed = positionEmbed.Interpolate(ctx, [4]int{target, target, hiddenStates.Dim(0), 1}, ml.SamplingModeBilinear)
		return positionEmbed.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
	}

	return m.PositionEmbedding
}

func (m *samModel) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
	hiddenStates := m.PatchEmbedding.Forward(ctx, t, 16, 16, 0, 0, 1, 1)
	hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

	if m.PositionEmbedding != nil {
		hiddenStates = hiddenStates.Add(ctx, m.absolutePositionEmbedding(ctx, hiddenStates))
	}

	for i, block := range m.Blocks {
		var windowSize int
		if !slices.Contains(m.Options.globalAttentionLayers, int32(i)) {
			windowSize = 14
		}

		hiddenStates = block.Forward(ctx, hiddenStates, windowSize, m.Options)
	}

	hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
	hiddenStates = m.Neck.Forward(ctx, hiddenStates, m.Options)
	hiddenStates = m.Net2.Forward(ctx, hiddenStates, 2, 2, 1, 1, 1, 1)
	hiddenStates = m.Net3.Forward(ctx, hiddenStates, 2, 2, 1, 1, 1, 1)
	return hiddenStates
}

type samOptions struct {
	hiddenSize,
	numHeads int
	eps                   float32
	globalAttentionLayers []int32
}

func (o samOptions) headDim() int {
	return o.hiddenSize / o.numHeads
}

type samBlock struct {
	Norm1       *nn.LayerNorm `gguf:"norm1"`
	Attention   *samAttention `gguf:"attn"`
	Norm2       *nn.LayerNorm `gguf:"norm2"`
	FeedForward *samMLP       `gguf:"mlp"`
}

func (m *samBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, windowSize int, opts samOptions) ml.Tensor {
	c, w, h := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)

	residual := hiddenStates
	hiddenStates = m.Norm1.Forward(ctx, hiddenStates, opts.eps)

	var pw, ph int
	if windowSize > 0 {
		pw = (windowSize - hiddenStates.Dim(1)%windowSize) % windowSize
		ph = (windowSize - hiddenStates.Dim(2)%windowSize) % windowSize
		if pw > 0 || ph > 0 {
			hiddenStates = hiddenStates.Pad(ctx, 0, pw, ph, 0)
		}

		hiddenStates = hiddenStates.Reshape(ctx, c*windowSize, (w+pw)/windowSize, windowSize, -1)
		hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, c, windowSize, windowSize, -1)
	}

	hiddenStates = m.Attention.Forward(ctx, hiddenStates, opts)

	if windowSize > 0 {
		hiddenStates = hiddenStates.Reshape(ctx, c*windowSize, windowSize, (w+pw)/windowSize, -1)
		hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3)
		hiddenStates = hiddenStates.Contiguous(ctx, c, w+pw, h+ph, -1)
		hiddenStates = hiddenStates.Pad(ctx, 0, -pw, -ph, 0)
	}

	hiddenStates = hiddenStates.Add(ctx, residual)

	residual = hiddenStates
	hiddenStates = m.Norm2.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = m.FeedForward.Forward(ctx, hiddenStates, opts)
	return hiddenStates.Add(ctx, residual)
}

type samAttention struct {
	QKV    *nn.Linear `gguf:"qkv"`
	Output *nn.Linear `gguf:"proj"`

	RelativePosition *struct {
		Height ml.Tensor `gguf:"h"`
		Width  ml.Tensor `gguf:"w"`
	} `gguf:",pre:rel_pos_"`
}

func relativeCoordinates(ctx ml.Context, qn, kn int) ml.Tensor {
	s := make([]int32, qn*kn)
	for i := range qn {
		for j := range kn {
			q := i * max(kn/qn, 1)
			k := j * max(qn/kn, 1)
			s[i*kn+j] = int32(q - k + (kn-1)*max(qn/kn, 1))
		}
	}
	return ctx.Input().FromInts(s, qn*kn)
}

func relativePositions(ctx ml.Context, positions ml.Tensor, qn, kn int) ml.Tensor {
	maxRelativeDistance := 2*max(qn, kn) - 1
	if positions.Dim(1) != maxRelativeDistance {
		// linear interpolation kernel not available so approx. with bilinear interpolation
		positions = positions.Interpolate(ctx, [4]int{positions.Dim(0), maxRelativeDistance, 1, 1}, ml.SamplingModeBilinear)
	}

	rc := relativeCoordinates(ctx, qn, kn)
	return positions.Rows(ctx, rc).Reshape(ctx, positions.Dim(0), kn, qn)
}

func (m *samAttention) decomposedRelativePositions(ctx ml.Context, query ml.Tensor, qn, kn []int) (ml.Tensor, ml.Tensor) {
	qh, qw := qn[0], qn[1]
	kh, kw := kn[0], kn[1]

	rh := relativePositions(ctx, m.RelativePosition.Height, qh, kh)
	rw := relativePositions(ctx, m.RelativePosition.Width, qw, kw)

	query = query.Contiguous(ctx, query.Dim(0), qw, qh, -1)
	rh = rh.Mulmat(ctx, query).Reshape(ctx, 1, kh, qh*qw, -1)
	rw = rw.Mulmat(ctx, query.Permute(ctx, 0, 2, 1, 3)).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, kw, 1, qh*qw, -1)
	return rh, rw
}

func (m *samAttention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
	w, h, b := hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3)

	qkv := m.QKV.Forward(ctx, hiddenStates)
	qkv = qkv.Reshape(ctx, opts.headDim(), -1, w*h, b)
	chunks := qkv.Chunk(ctx, 1, opts.numHeads)
	query, key, value := chunks[0], chunks[1], chunks[2]

	ctx.Forward(query, key, value)

	query = query.Permute(ctx, 0, 2, 1, 3)
	rh, rw := m.decomposedRelativePositions(ctx, query, []int{h, w}, []int{h, w})
	mask := rh.Repeat(ctx, 0, rw.Dim(0)).Add(ctx, rw)
	mask = mask.Reshape(ctx, h*w, -1, opts.numHeads, b)

	key = key.Permute(ctx, 0, 2, 1, 3)
	scores := key.MulmatFullPrec(ctx, query)
	scores = scores.Scale(ctx, 1/math.Sqrt(float64(opts.headDim())))

	scores = scores.Add(ctx, mask)
	scores = scores.Softmax(ctx)

	value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
	attention := value.Mulmat(ctx, scores)
	attention = attention.Permute(ctx, 0, 2, 1, 3)
	attention = attention.Contiguous(ctx, -1, w, h, b)
	return m.Output.Forward(ctx, attention)
}

type samMLP struct {
	Lin1 *nn.Linear `gguf:"lin1"`
	Lin2 *nn.Linear `gguf:"lin2"`
}

func (m *samMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
	return m.Lin2.Forward(ctx, m.Lin1.Forward(ctx, hiddenStates).GELU(ctx))
}

type LayerNorm2D struct {
	Weight ml.Tensor `gguf:"weight"`
	Bias   ml.Tensor `gguf:"bias"`
}

func (ln *LayerNorm2D) Forward(ctx ml.Context, x ml.Tensor, eps float32) ml.Tensor {
	x = x.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
	u := x.Mean(ctx)
	d := x.Sub(ctx, u)
	s := d.Sqr(ctx).Mean(ctx)
	x = d.Div(ctx, s.Add(ctx, ctx.Input().FromFloats([]float32{eps}, 1)).Sqrt(ctx))
	x = x.Mul(ctx, ln.Weight).Add(ctx, ln.Bias)
	return x.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
}

type samNeck struct {
	C1  *nn.Conv2D   `gguf:"0"`
	LN1 *LayerNorm2D `gguf:"1"`
	C2  *nn.Conv2D   `gguf:"2"`
	LN2 *LayerNorm2D `gguf:"3"`
}

func (m *samNeck) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts samOptions) ml.Tensor {
	hiddenStates = m.C1.Forward(ctx, hiddenStates, 1, 1, 0, 0, 1, 1)
	hiddenStates = m.LN1.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = m.C2.Forward(ctx, hiddenStates, 1, 1, 1, 1, 1, 1)
	hiddenStates = m.LN2.Forward(ctx, hiddenStates, opts.eps)
	return hiddenStates
}
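Blocks outside globalAttentionLayers attend within 14x14 windows, so samBlock.Forward pads the patch grid up to a window multiple before attention and crops back afterwards. A quick standalone check of just the padding arithmetic, not part of the commit:

package main

import "fmt"

// pad returns how many cells to add so dim becomes a multiple of window,
// the same (window - dim%window) % window formula used above.
func pad(dim, window int) int { return (window - dim%window) % window }

func main() {
	w, h, window := 64, 64, 14
	fmt.Println(pad(w, window), pad(h, window)) // 6 6 -> padded to 70x70, i.e. a 5x5 grid of windows
}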
model/models/deepseekocr/model_text.go (new file, 140 lines)
@@ -0,0 +1,140 @@
package deepseekocr

import (
	"math"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
	"github.com/ollama/ollama/ml/nn/fast"
	"github.com/ollama/ollama/ml/nn/rope"
)

type textModel struct {
	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
	Blocks         []textBlock   `gguf:"blk"`
	OutputNorm     *nn.RMSNorm   `gguf:"output_norm"`
	Output         *nn.Linear    `gguf:"output"`

	Options textOptions
}

func (m *textModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
	return m.Options.applyRotaryPositionalEmbedding(ctx, key, shift), nil
}

type textOptions struct {
	hiddenSize,
	numHeads,
	numKVHeads,
	numExperts,
	numExpertsUsed int
	ropeBase,
	ropeScale,
	eps float32
}

func (o textOptions) headDim() int {
	return o.hiddenSize / o.numHeads
}

func (o textOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
	return fast.RoPE(ctx, t, p, o.headDim(), o.ropeBase, 1/o.ropeScale, rope.WithTypeNeoX())
}

type textBlock struct {
	AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
	Attention     *textAttention
	MLPNNorm      *nn.RMSNorm `gguf:"ffn_norm"`
	FeedForward   textFeedForward
}

func (m *textBlock) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts textOptions) ml.Tensor {
	residual := hiddenStates
	hiddenStates = m.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = m.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
	if outputs != nil {
		hiddenStates = hiddenStates.Rows(ctx, outputs)
		residual = residual.Rows(ctx, outputs)
	}

	hiddenStates = hiddenStates.Add(ctx, residual)

	residual = hiddenStates
	hiddenStates = m.MLPNNorm.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = m.FeedForward.Forward(ctx, hiddenStates, opts)
	return hiddenStates.Add(ctx, residual)
}

type textAttention struct {
	Query  *nn.Linear `gguf:"attn_q"`
	Key    *nn.Linear `gguf:"attn_k"`
	Value  *nn.Linear `gguf:"attn_v"`
	Output *nn.Linear `gguf:"attn_output"`
}

func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts textOptions) ml.Tensor {
	query := m.Query.Forward(ctx, hiddenStates)
	query = query.Reshape(ctx, opts.headDim(), opts.numHeads, -1)

	key := m.Key.Forward(ctx, hiddenStates)
	key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1)

	value := m.Value.Forward(ctx, hiddenStates)
	value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1)

	query = opts.applyRotaryPositionalEmbedding(ctx, query, positions)
	key = opts.applyRotaryPositionalEmbedding(ctx, key, positions)

	attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
	attention = attention.Reshape(ctx, -1, attention.Dim(2))
	return m.Output.Forward(ctx, attention)
}

type textFeedForward interface {
	Forward(ml.Context, ml.Tensor, textOptions) ml.Tensor
}

type textMoe struct {
	Router        *nn.Linear      `gguf:"ffn_gate_inp"`
	Gate          *nn.LinearBatch `gguf:"ffn_gate_exps"`
	Up            *nn.LinearBatch `gguf:"ffn_up_exps"`
	Down          *nn.LinearBatch `gguf:"ffn_down_exps"`
	SharedExperts *textMLP        `gguf:",suf:_shexp"`
}

func (m *textMoe) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts textOptions) ml.Tensor {
	scores := m.Router.Forward(ctx, hiddenStates).Softmax(ctx)
	indices := scores.TopK(ctx, opts.numExpertsUsed)
	weights := scores.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, indices)

	experts := hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
	experts = m.Gate.Forward(ctx, experts, indices).SILU(ctx, m.Up.Forward(ctx, experts, indices))
	experts = m.Down.Forward(ctx, experts, indices)
	experts = experts.Mul(ctx, weights)

	expert := func(i int) ml.Tensor {
		return experts.View(
			ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2),
		)
	}

	routedStates := expert(0)
	for i := 1; i < opts.numExpertsUsed; i++ {
		routedStates = routedStates.Add(ctx, expert(i))
	}

	sharedStates := m.SharedExperts.Forward(ctx, hiddenStates, opts)
	return routedStates.Add(ctx, sharedStates)
}

type textMLP struct {
	Gate *nn.Linear `gguf:"ffn_gate"`
	Up   *nn.Linear `gguf:"ffn_up"`
	Down *nn.Linear `gguf:"ffn_down"`
}

func (m *textMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ textOptions) ml.Tensor {
	hiddenStates = m.Gate.Forward(ctx, hiddenStates).SILU(ctx, m.Up.Forward(ctx, hiddenStates))
	return m.Down.Forward(ctx, hiddenStates)
}
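textMoe.Forward routes each token through the top-k of the softmaxed router scores and adds the shared-expert path on top. A scalar toy version of that routing, not part of the commit; the real code operates on batched tensors via TopK/Rows and does not renormalize the selected scores:

package main

import (
	"fmt"
	"math"
	"sort"
)

// route softmaxes the router logits, keeps the top-k experts, and sums
// their outputs scaled by the corresponding softmax scores.
func route(logits []float64, k int, experts []func(float64) float64, x float64) float64 {
	scores := make([]float64, len(logits))
	var sum float64
	for i, l := range logits {
		scores[i] = math.Exp(l)
		sum += scores[i]
	}
	for i := range scores {
		scores[i] /= sum
	}

	idx := make([]int, len(scores))
	for i := range idx {
		idx[i] = i
	}
	sort.Slice(idx, func(a, b int) bool { return scores[idx[a]] > scores[idx[b]] })

	var out float64
	for _, i := range idx[:k] {
		out += scores[i] * experts[i](x)
	}
	return out
}

func main() {
	experts := []func(float64) float64{
		func(x float64) float64 { return x },
		func(x float64) float64 { return 2 * x },
		func(x float64) float64 { return -x },
	}
	fmt.Println(route([]float64{0.1, 2.0, -1.0}, 2, experts, 1.0)) // experts 1 and 0 are selected
}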
model/models/deepseekocr/model_vision.go (new file, 117 lines)
@@ -0,0 +1,117 @@
package deepseekocr

import (
	"math"

	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn"
)

type visionModel struct {
	PatchEmbedding    *nn.Conv2D    `gguf:"patch_embd"`
	ClassEmbedding    ml.Tensor     `gguf:"class_embd"`
	PositionEmbedding *nn.Embedding `gguf:"position_embd"`

	PreLayerNorm *nn.LayerNorm `gguf:"pre_layrnorm"`
	Blocks       []visionBlock `gguf:"blk"`

	Options visionOptions
}

func (m *visionModel) absolutePositionEmbedding(ctx ml.Context, embeds ml.Tensor) ml.Tensor {
	numPatches := m.Options.imageSize / m.Options.patchSize * m.Options.imageSize / m.Options.patchSize
	positions := ctx.Arange(0, float32(numPatches+1), 1, ml.DTypeI32)
	positionEmbeds := m.PositionEmbedding.Forward(ctx, positions)

	source := int(math.Sqrt(float64(positionEmbeds.Dim(1) - 1)))
	target := int(math.Sqrt(float64(embeds.Dim(1) - 1)))
	if source != target {
		newPositionEmbeds := positionEmbeds.Slice(ctx, 1, 1, positionEmbeds.Dim(1), 1)
		newPositionEmbeds = newPositionEmbeds.Reshape(ctx, -1, source, source)
		newPositionEmbeds = newPositionEmbeds.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
		newPositionEmbeds = newPositionEmbeds.Interpolate(ctx, [4]int{target, target, embeds.Dim(0), 1}, ml.SamplingModeBilinear)
		newPositionEmbeds = newPositionEmbeds.Permute(ctx, 1, 2, 0, 3)
		newPositionEmbeds = newPositionEmbeds.Contiguous(ctx, -1, target*target)

		positionEmbeds = positionEmbeds.Slice(ctx, 1, 0, 1, 1).Concat(ctx, newPositionEmbeds, 1)
	}

	return positionEmbeds
}

func (m *visionModel) Forward(ctx ml.Context, pixelValues, patchEmbeds ml.Tensor) ml.Tensor {
	if patchEmbeds == nil {
		patchEmbeds = m.PatchEmbedding.Forward(ctx, pixelValues, m.Options.patchSize, m.Options.patchSize, 0, 0, 1, 1)
	}

	patchEmbeds = patchEmbeds.Reshape(ctx, -1, patchEmbeds.Dim(2), patchEmbeds.Dim(3))
	patchEmbeds = patchEmbeds.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)

	classEmbeds := m.ClassEmbedding.Repeat(ctx, 2, patchEmbeds.Dim(2))
	embeds := classEmbeds.Concat(ctx, patchEmbeds, 1)
	embeds = embeds.Add(ctx, m.absolutePositionEmbedding(ctx, embeds))

	hiddenStates := m.PreLayerNorm.Forward(ctx, embeds, m.Options.eps)
	for _, block := range m.Blocks {
		hiddenStates = block.Forward(ctx, hiddenStates, m.Options)
	}

	return hiddenStates
}

type visionOptions struct {
	hiddenSize,
	numHeads int
	eps float32

	imageSize, patchSize int
}

func (o visionOptions) headDim() int {
	return o.hiddenSize / o.numHeads
}

type visionBlock struct {
	Norm1       *nn.LayerNorm    `gguf:"layer_norm1"`
	Attention   *visionAttention `gguf:"self_attn"`
	Norm2       *nn.LayerNorm    `gguf:"layer_norm2"`
	FeedForward *visionMLP       `gguf:"mlp"`
}

func (m *visionBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts visionOptions) ml.Tensor {
	residual := hiddenStates
	hiddenStates = m.Norm1.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = m.Attention.Forward(ctx, hiddenStates, opts)
	hiddenStates = hiddenStates.Add(ctx, residual)

	residual = hiddenStates
	hiddenStates = m.Norm2.Forward(ctx, hiddenStates, opts.eps)
	hiddenStates = m.FeedForward.Forward(ctx, hiddenStates)
	hiddenStates = hiddenStates.Add(ctx, residual)
	return hiddenStates
}

type visionAttention struct {
	QKV    *nn.Linear `gguf:"qkv_proj"`
	Output *nn.Linear `gguf:"out_proj"`
}

func (m *visionAttention) Forward(ctx ml.Context, t ml.Tensor, opts visionOptions) ml.Tensor {
	qkv := m.QKV.Forward(ctx, t)
	qkv = qkv.Reshape(ctx, opts.headDim(), -1, qkv.Dim(1), qkv.Dim(2))
	chunks := qkv.Chunk(ctx, 1, opts.numHeads)
	query, key, value := chunks[0], chunks[1], chunks[2]

	attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
	attention = attention.Reshape(ctx, -1, attention.Dim(2), attention.Dim(3))
	return m.Output.Forward(ctx, attention)
}

type visionMLP struct {
	FC1 *nn.Linear `gguf:"fc1"`
	FC2 *nn.Linear `gguf:"fc2"`
}

func (m *visionMLP) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
	return m.FC2.Forward(ctx, m.FC1.Forward(ctx, t).QuickGELU(ctx))
}
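Both attention modules compute a fused QKV projection and then split the head axis into query, key, and value groups; this assumes Chunk divides dimension 1 (of size 3*numHeads head-slices) into groups of numHeads. A toy check of that slicing on a flat buffer, not part of the commit:

package main

import "fmt"

func main() {
	numHeads, headDim := 2, 3
	// A fused QKV buffer of 3*numHeads head-slices, headDim values each.
	qkv := make([]int, 3*numHeads*headDim)
	for i := range qkv {
		qkv[i] = i
	}
	chunk := numHeads * headDim
	q, k, v := qkv[:chunk], qkv[chunk:2*chunk], qkv[2*chunk:]
	fmt.Println(q, k, v) // [0 1 2 3 4 5] [6 7 8 9 10 11] [12 13 14 15 16 17]
}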
@@ -3,6 +3,7 @@ package models
 import (
 	_ "github.com/ollama/ollama/model/models/bert"
 	_ "github.com/ollama/ollama/model/models/deepseek2"
+	_ "github.com/ollama/ollama/model/models/deepseekocr"
 	_ "github.com/ollama/ollama/model/models/gemma2"
 	_ "github.com/ollama/ollama/model/models/gemma3"
 	_ "github.com/ollama/ollama/model/models/gemma3n"