package engine

import (
	"context"
	"fmt"
	"math/rand"
	"runtime"
	"sort"
	"strings"
	"time"

	"git.kingecg.top/kingecg/gomog/pkg/types"
)

// StreamAggregationEngine is a streaming aggregation engine: it evaluates an
// aggregation pipeline batch by batch instead of materializing the whole
// collection in memory at once.
type StreamAggregationEngine struct {
	store *MemoryStore
}

// NewStreamAggregationEngine creates a streaming aggregation engine backed by store.
func NewStreamAggregationEngine(store *MemoryStore) *StreamAggregationEngine {
	return &StreamAggregationEngine{store: store}
}

// StreamExecute runs the aggregation pipeline over the named collection and
// streams processed batches on the returned result channel. Errors (including
// context cancellation) are reported on the error channel. Both channels are
// closed when the background worker finishes.
func (e *StreamAggregationEngine) StreamExecute(
	ctx context.Context,
	collection string,
	pipeline []types.AggregateStage,
	opts StreamAggregationOptions,
) (<-chan []types.Document, <-chan error) {
	if opts.BufferSize <= 0 {
		opts.BufferSize = 100
	}
	if opts.MaxConcurrency <= 0 {
		opts.MaxConcurrency = runtime.NumCPU()
	}

	resultChan := make(chan []types.Document, opts.BufferSize)
	errChan := make(chan error, 1)

	go func() {
		defer close(resultChan)
		defer close(errChan)

		// Acquire a batch iterator over the collection's documents.
		docIter, err := e.store.GetDocumentIterator(collection, opts.BufferSize)
		if err != nil {
			errChan <- err
			return
		}
		defer docIter.Close()

		// Process documents batch by batch.
		for docIter.HasNext() {
			select {
			case <-ctx.Done():
				errChan <- ctx.Err()
				return
			default:
			}

			batch, err := docIter.NextBatch()
			if err != nil {
				errChan <- err
				return
			}
			if len(batch) == 0 {
				continue
			}

			// Run the pipeline over this batch.
			processed, err := e.processBatch(ctx, batch, pipeline, opts)
			if err != nil {
				errChan <- err
				return
			}
			if len(processed) > 0 {
				// Respect cancellation while sending so a slow or absent
				// consumer cannot leak this goroutine after the context ends.
				select {
				case resultChan <- processed:
				case <-ctx.Done():
					errChan <- ctx.Err()
					return
				}
			}
		}
	}()

	return resultChan, errChan
}

// processBatch runs every pipeline stage over a single batch of documents,
// short-circuiting on cancellation or when a stage empties the batch.
func (e *StreamAggregationEngine) processBatch(
	ctx context.Context,
	batch []types.Document,
	pipeline []types.AggregateStage,
	opts StreamAggregationOptions,
) ([]types.Document, error) {
	result := batch
	for _, stage := range pipeline {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}

		var err error
		result, err = e.executeStageStreaming(stage, result, opts)
		if err != nil {
			return nil, err
		}
		// Nothing left for later stages to operate on — stop early.
		if len(result) == 0 {
			break
		}
	}
	return result, nil
}

// executeStageStreaming dispatches a single pipeline stage over one batch.
// Stages that fundamentally require the full dataset ($group, $lookup, ...)
// return an error directing callers to the regular (non-streaming) engine.
func (e *StreamAggregationEngine) executeStageStreaming(
	stage types.AggregateStage,
	docs []types.Document,
	opts StreamAggregationOptions,
) ([]types.Document, error) {
	switch stage.Stage {
	case "$match":
		return e.executeMatch(stage.Spec, docs)
	case "$project":
		return e.executeProject(stage.Spec, docs)
	case "$limit":
		return e.executeLimit(stage.Spec, docs)
	case "$skip":
		return e.executeSkip(stage.Spec, docs)
	case "$sort":
		// $sort needs the full dataset for a globally correct order, but we
		// can still sort within the current batch.
		return e.executeSort(stage.Spec, docs)
	case "$unwind":
		return e.executeUnwind(stage.Spec, docs)
	case "$addFields", "$set":
		return e.executeAddFields(stage.Spec, docs)
	case "$unset":
		return e.executeUnset(stage.Spec, docs)
	case "$sample":
		return e.executeSample(stage.Spec, docs)
	case "$replaceRoot":
		return e.executeReplaceRoot(stage.Spec, docs)
	case "$replaceWith":
		return e.executeReplaceWith(stage.Spec, docs)
	// Stages below need global data ($group, $lookup, $graphLookup, ...) and
	// cannot be evaluated one batch at a time.
	case "$group":
		return nil, fmt.Errorf("$group stage cannot be processed in streaming mode, use regular aggregation instead")
	case "$lookup":
		// Needs the full contents of another collection.
		return nil, fmt.Errorf("$lookup stage cannot be processed in streaming mode, use regular aggregation instead")
	case "$graphLookup":
		return nil, fmt.Errorf("$graphLookup stage cannot be processed in streaming mode, use regular aggregation instead")
	case "$unionWith":
		// Needs the full contents of another collection.
		return nil, fmt.Errorf("$unionWith stage cannot be processed in streaming mode, use regular aggregation instead")
	case "$redact":
		return e.executeRedact(stage.Spec, docs)
	case "$indexStats", "$collStats":
		// Statistics stages need the complete dataset.
		return nil, fmt.Errorf("$indexStats and $collStats stages cannot be processed in streaming mode, use regular aggregation instead")
	case "$out", "$merge":
		// Output stages are handled separately (and must come last).
		return e.executeOutputStages(stage, docs)
	default:
		// Unknown stage: pass documents through unchanged.
		return docs, nil
	}
}

// executeOutputStages handles terminal output stages. Neither $out nor $merge
// is currently supported while streaming.
func (e *StreamAggregationEngine) executeOutputStages(
	stage types.AggregateStage,
	docs []types.Document,
) ([]types.Document, error) {
	switch stage.Stage {
	case "$out":
		return docs, fmt.Errorf("$out not supported in streaming mode")
	case "$merge":
		return docs, fmt.Errorf("$merge not supported in streaming mode")
	default:
		return docs, nil
	}
}

// executeAddFields implements $addFields / $set: each listed field is set to
// the value of its evaluated expression on a deep copy of the document.
func (e *StreamAggregationEngine) executeAddFields(spec interface{}, docs []types.Document) ([]types.Document, error) {
	addFieldsSpec, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}

	var results []types.Document
	for _, doc := range docs {
		// Deep-copy so the source document is never mutated.
		newData := deepCopyMap(doc.Data)
		for field, expr := range addFieldsSpec {
			newData[field] = e.evaluateExpression(newData, expr)
		}
		results = append(results, types.Document{
			ID:   doc.ID,
			Data: newData,
		})
	}
	return results, nil
}

// executeUnset implements $unset: removes the named top-level fields from a
// deep copy of each document. Accepts a single field name or an array.
func (e *StreamAggregationEngine) executeUnset(spec interface{}, docs []types.Document) ([]types.Document, error) {
	unsetSpec, ok := spec.([]interface{})
	if !ok {
		// A bare string is treated as a one-element array.
		if str, isStr := spec.(string); isStr {
			unsetSpec = []interface{}{str}
		} else {
			return docs, nil
		}
	}

	var results []types.Document
	for _, doc := range docs {
		newData := deepCopyMap(doc.Data)
		for _, field := range unsetSpec {
			if fieldName, isStr := field.(string); isStr {
				delete(newData, fieldName)
			}
		}
		results = append(results, types.Document{
			ID:   doc.ID,
			Data: newData,
		})
	}
	return results, nil
}

// executeSample implements $sample: returns `size` documents chosen uniformly
// at random (all documents when size >= len(docs), none when size <= 0).
func (e *StreamAggregationEngine) executeSample(spec interface{}, docs []types.Document) ([]types.Document, error) {
	sampleSpec, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}

	// Accept the same numeric spellings as $limit/$skip.
	var count int
	switch v := sampleSpec["size"].(type) {
	case int:
		count = v
	case int64:
		count = int(v)
	case float64:
		count = int(v)
	default:
		return docs, nil
	}

	if count >= len(docs) {
		return docs, nil
	}
	if count <= 0 {
		return []types.Document{}, nil
	}

	shuffled := make([]types.Document, len(docs))
	copy(shuffled, docs)

	// Partial Fisher–Yates shuffle: after count iterations the first count
	// elements are a uniform random sample. (The previous variant passed a
	// non-positive argument to Intn for count > len/2, which panics.)
	source := rand.NewSource(time.Now().UnixNano())
	rng := rand.New(source)
	for i := 0; i < count; i++ {
		j := i + rng.Intn(len(shuffled)-i)
		shuffled[i], shuffled[j] = shuffled[j], shuffled[i]
	}
	return shuffled[:count], nil
}

// executeReplaceRoot implements $replaceRoot for a string field path in
// "newRoot": the named sub-document becomes the new document body. Non-object
// values are wrapped in a one-key object keyed by the field path.
func (e *StreamAggregationEngine) executeReplaceRoot(spec interface{}, docs []types.Document) ([]types.Document, error) {
	replaceRootSpec, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}
	newRootField, ok := replaceRootSpec["newRoot"].(string)
	if !ok {
		return docs, nil
	}

	var results []types.Document
	for _, doc := range docs {
		newRoot := getNestedValue(doc.Data, newRootField)
		if newRootMap, ok := newRoot.(map[string]interface{}); ok {
			results = append(results, types.Document{
				ID:   doc.ID,
				Data: newRootMap,
			})
		} else {
			// Not an object — wrap the value so the document stays a map.
			results = append(results, types.Document{
				ID:   doc.ID,
				Data: map[string]interface{}{newRootField: newRoot},
			})
		}
	}
	return results, nil
}

// executeReplaceWith implements $replaceWith: the evaluated expression becomes
// the new document body; non-object results are wrapped under "value".
func (e *StreamAggregationEngine) executeReplaceWith(spec interface{}, docs []types.Document) ([]types.Document, error) {
	var results []types.Document
	for _, doc := range docs {
		newData := e.evaluateExpression(doc.Data, spec)
		if newDataMap, ok := newData.(map[string]interface{}); ok {
			results = append(results, types.Document{
				ID:   doc.ID,
				Data: newDataMap,
			})
		} else {
			results = append(results, types.Document{
				ID:   doc.ID,
				Data: map[string]interface{}{"value": newData},
			})
		}
	}
	return results, nil
}

// executeRedact is a placeholder for $redact; not yet implemented in
// streaming mode.
func (e *StreamAggregationEngine) executeRedact(spec interface{}, docs []types.Document) ([]types.Document, error) {
	return nil, fmt.Errorf("$redact stage not yet implemented in streaming mode")
}

// evaluateExpression evaluates a (subset of the) aggregation expression
// language against a document: "$field" references resolve via
// getNestedValue, operator objects dispatch to the helpers below, and
// anything else is returned as a literal.
func (e *StreamAggregationEngine) evaluateExpression(data map[string]interface{}, expr interface{}) interface{} {
	// Field reference: a string starting with '$'.
	if fieldStr, ok := expr.(string); ok && len(fieldStr) > 0 && fieldStr[0] == '$' {
		fieldName := fieldStr[1:] // strip the '$' prefix
		return getNestedValue(data, fieldName)
	}

	if exprMap, ok := expr.(map[string]interface{}); ok {
		for op, operand := range exprMap {
			switch op {
			case "$concat":
				return e.concat(operand, data)
			case "$toUpper":
				return strings.ToUpper(fmt.Sprintf("%v", e.getFieldValueStr(types.Document{Data: data}, operand)))
			case "$toLower":
				return strings.ToLower(fmt.Sprintf("%v", e.getFieldValueStr(types.Document{Data: data}, operand)))
			case "$add":
				return e.add(operand, data)
			case "$multiply":
				return e.multiply(operand, data)
			case "$ifNull":
				return e.ifNull(operand, data)
			case "$cond":
				return e.cond(operand, data)
				// More operators can be added here as needed.
			}
		}
	}
	// Literal value (or unrecognized operator object).
	return expr
}

// concat implements $concat: evaluates each array element and joins their
// string representations.
func (e *StreamAggregationEngine) concat(operand interface{}, data map[string]interface{}) interface{} {
	if arr, ok := operand.([]interface{}); ok {
		var b strings.Builder
		for _, item := range arr {
			evaluated := e.evaluateExpression(data, item)
			fmt.Fprintf(&b, "%v", evaluated)
		}
		return b.String()
	}
	return ""
}

// getFieldValueStr resolves field to a string, or "" when it is not a string.
func (e *StreamAggregationEngine) getFieldValueStr(doc types.Document, field interface{}) string {
	if str, ok := e.getFieldValue(doc, field).(string); ok {
		return str
	}
	return ""
}

// getFieldValue resolves "$path" strings against the document; other values
// pass through unchanged.
func (e *StreamAggregationEngine) getFieldValue(doc types.Document, field interface{}) interface{} {
	switch f := field.(type) {
	case string:
		if len(f) > 0 && f[0] == '$' {
			return getNestedValue(doc.Data, f[1:])
		}
		return f
	default:
		return field
	}
}

// add implements $add: sums the evaluated array elements as float64.
func (e *StreamAggregationEngine) add(operand interface{}, data map[string]interface{}) interface{} {
	if arr, ok := operand.([]interface{}); ok {
		sum := 0.0
		for _, item := range arr {
			evaluated := e.evaluateExpression(data, item)
			sum += toFloat64(evaluated)
		}
		return sum
	}
	return 0
}

// multiply implements $multiply: multiplies the evaluated array elements.
func (e *StreamAggregationEngine) multiply(operand interface{}, data map[string]interface{}) interface{} {
	if arr, ok := operand.([]interface{}); ok {
		result := 1.0
		for _, item := range arr {
			evaluated := e.evaluateExpression(data, item)
			result *= toFloat64(evaluated)
		}
		return result
	}
	return 0
}

// ifNull implements $ifNull for the two-element array form: returns the first
// expression when non-nil, otherwise the second.
func (e *StreamAggregationEngine) ifNull(operand interface{}, data map[string]interface{}) interface{} {
	if arr, ok := operand.([]interface{}); ok && len(arr) == 2 {
		evaluatedFirst := e.evaluateExpression(data, arr[0])
		if evaluatedFirst != nil {
			return evaluatedFirst
		}
		return e.evaluateExpression(data, arr[1])
	}
	return nil
}

// cond implements the object form of $cond ({if, then, else}).
func (e *StreamAggregationEngine) cond(operand interface{}, data map[string]interface{}) interface{} {
	if condMap, ok := operand.(map[string]interface{}); ok {
		ifCond, hasIf := condMap["if"]
		thenVal, hasThen := condMap["then"]
		elseVal, hasElse := condMap["else"]
		if hasIf && hasThen && hasElse {
			ifVal := e.evaluateExpression(data, ifCond)
			if isTrue(ifVal) {
				return e.evaluateExpression(data, thenVal)
			}
			return e.evaluateExpression(data, elseVal)
		}
	}
	return nil
}

// executeMatch implements $match: keeps documents satisfying the filter.
// Unrecognized spec types pass all documents through.
func (e *StreamAggregationEngine) executeMatch(spec interface{}, docs []types.Document) ([]types.Document, error) {
	var filter map[string]interface{}
	if f, ok := spec.(types.Filter); ok {
		filter = f
	} else if f, ok := spec.(map[string]interface{}); ok {
		filter = f
	} else {
		return docs, nil
	}

	var results []types.Document
	for _, doc := range docs {
		if MatchFilter(doc.Data, filter) {
			results = append(results, doc)
		}
	}
	return results, nil
}

// executeProject implements $project: shapes each document per the spec.
func (e *StreamAggregationEngine) executeProject(spec interface{}, docs []types.Document) ([]types.Document, error) {
	projectSpec, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}

	var results []types.Document
	for _, doc := range docs {
		projected := e.projectDocument(doc.Data, projectSpec)
		results = append(results, types.Document{
			ID:   doc.ID,
			Data: projected,
		})
	}
	return results, nil
}

// projectDocument builds the projected view of a single document: truthy spec
// values include the field, falsy values exclude it, and anything else is
// evaluated as an expression. "_id" is included unless explicitly excluded.
func (e *StreamAggregationEngine) projectDocument(data map[string]interface{}, spec map[string]interface{}) map[string]interface{} {
	result := make(map[string]interface{})
	for field, include := range spec {
		if field == "_id" {
			// "_id" is kept by default; only an explicit falsy value drops it.
			if !isFalse(include) {
				result["_id"] = data["_id"]
			}
			continue
		}

		if isTrue(include) {
			result[field] = getNestedValue(data, field)
		} else if isFalse(include) {
			// Exclusions are ignored in inclusion mode.
			continue
		} else {
			// Computed field.
			result[field] = e.evaluateExpression(data, include)
		}
	}
	return result
}

// executeLimit implements $limit; non-positive or oversized limits are no-ops.
func (e *StreamAggregationEngine) executeLimit(spec interface{}, docs []types.Document) ([]types.Document, error) {
	limit := 0
	switch l := spec.(type) {
	case int:
		limit = l
	case int64:
		limit = int(l)
	case float64:
		limit = int(l)
	}
	if limit <= 0 || limit >= len(docs) {
		return docs, nil
	}
	return docs[:limit], nil
}

// executeSkip implements $skip; skipping past the end yields an empty slice.
func (e *StreamAggregationEngine) executeSkip(spec interface{}, docs []types.Document) ([]types.Document, error) {
	skip := 0
	switch s := spec.(type) {
	case int:
		skip = s
	case int64:
		skip = int(s)
	case float64:
		skip = int(s)
	}
	if skip <= 0 {
		return docs, nil
	}
	if skip >= len(docs) {
		return []types.Document{}, nil
	}
	return docs[skip:], nil
}

// executeUnwind implements $unwind: emits one document per element of the
// array at the given "$path", optionally preserving documents whose value is
// missing or not a non-empty array (preserveNullAndEmptyArrays).
func (e *StreamAggregationEngine) executeUnwind(spec interface{}, docs []types.Document) ([]types.Document, error) {
	var path string
	var preserveNull bool

	switch s := spec.(type) {
	case string:
		path = s
	case map[string]interface{}:
		if p, ok := s["path"].(string); ok {
			path = p
		}
		if pn, ok := s["preserveNullAndEmptyArrays"].(bool); ok {
			preserveNull = pn
		}
	}

	if path == "" || path[0] != '$' {
		return docs, nil
	}
	fieldPath := path[1:]

	var results []types.Document
	for _, doc := range docs {
		arr := getNestedValue(doc.Data, fieldPath)
		if arr == nil {
			if preserveNull {
				results = append(results, doc)
			}
			continue
		}

		array, ok := arr.([]interface{})
		if !ok || len(array) == 0 {
			if preserveNull {
				results = append(results, doc)
			}
			continue
		}

		for _, item := range array {
			newData := deepCopyMap(doc.Data)
			setNestedValue(newData, fieldPath, item)
			results = append(results, types.Document{
				ID:   doc.ID,
				Data: newData,
			})
		}
	}
	return results, nil
}

// executeSort implements $sort within the current batch (global order across
// batches is not guaranteed in streaming mode).
func (e *StreamAggregationEngine) executeSort(spec interface{}, docs []types.Document) ([]types.Document, error) {
	sortSpec, ok := spec.(map[string]interface{})
	if !ok {
		return docs, nil
	}

	// Normalize sort directions to ints (1 ascending, -1 descending).
	sortFields := make(map[string]int)
	for field, direction := range sortSpec {
		dir := 1
		switch d := direction.(type) {
		case int:
			dir = d
		case int64:
			dir = int(d)
		case float64:
			dir = int(d)
		}
		sortFields[field] = dir
	}

	// Sort a copy so the caller's slice order is untouched.
	sorted := make([]types.Document, len(docs))
	copy(sorted, docs)
	sort.Slice(sorted, func(i, j int) bool {
		return e.compareDocs(sorted[i], sorted[j], sortFields)
	})
	return sorted, nil
}

// compareDocs reports whether a sorts before b according to sortFields.
// NOTE(review): sortFields is a map, so with multiple sort keys the tie-break
// order is not deterministic — confirm against the regular engine's behavior.
func (e *StreamAggregationEngine) compareDocs(a, b types.Document, sortFields map[string]int) bool {
	for field, dir := range sortFields {
		valA := getNestedValue(a.Data, field)
		valB := getNestedValue(b.Data, field)
		cmp := compareValues(valA, valB)
		if cmp != 0 {
			if dir < 0 {
				return cmp > 0
			}
			return cmp < 0
		}
	}
	return false
}