Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion docs/en/transforms/embedding.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,58 @@ vectorization_fields {
}
```

**Multi-field Mixing Multimodal Vectorization:**
> Note: Currently, only the `DOUBAO` provider supports multimodal data processing.
```hocon
vectorization_fields {
# Multi-field text
multi_field_text_vector = [product_name, description]

# Multi-field image
multi_field_image_vector = [
{
field = product_image_url
modality = jpeg
format = url
},
{
field = thumbnail_image
modality = png
format = url
}
]

# Multi-field video
multi_field_video_vector = [
{
field = product_video_url
modality = mp4
format = url
},
{
field = promotional_video
modality = mov
format = url
}
]

# Multi-field mix multimodal
multi_field_mix_vector = [
product_name,
{
field = product_image_url
modality = jpeg
format = url
},
{
field = product_video_url
modality = mp4
format = url
}
]
}
```

**Field Specification Formats:**

**Supported Modality Types:**
Expand All @@ -162,7 +214,7 @@ vectorization_fields {
- `binary` - Binary data format

**Automatic Modality Detection:**
When `modality` is not explicitly specified and `format` is not `binary`, the system automatically detects the modality type based on the file suffix of the field value:
When `modality` is not explicitly specified and `format` is `url`, the system automatically detects the modality type based on the file suffix of the field value:

> **Important:** When using multimodal fields (image or video), ensure your model provider supports multimodal embedding. Image and video fields must contain valid URLs or binary data. Currently, `DOUBAO` provider supports multimodal data processing.

Expand Down
54 changes: 53 additions & 1 deletion docs/zh/transforms/embedding.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,58 @@ vectorization_fields {
}
```

**多字段混合多模态向量化:**
> 注意: 目前,仅 `DOUBAO` 提供商支持多模态数据处理
```hocon
vectorization_fields {
# 多字段文本
multi_field_text_vector = [product_name, description]

# 多字段图片
multi_field_image_vector = [
{
field = product_image_url
modality = jpeg
format = url
},
{
field = thumbnail_image
modality = png
format = url
}
]

# 多字段视频
multi_field_video_vector = [
{
field = product_video_url
modality = mp4
format = url
},
{
field = promotional_video
modality = mov
format = url
}
]

# 多字段混合多模态
multi_field_mix_vector = [
product_name,
{
field = product_image_url
modality = jpeg
format = url
},
{
field = product_video_url
modality = mp4
format = url
}
]
}
```

**字段规范格式:**

**支持的模态类型:**
Expand All @@ -147,7 +199,7 @@ vectorization_fields {
- `binary` - 二进制数据格式

**自动模态检测:**
当未显式指定 `modality` 且 `format` 不是 `binary` 时,系统会根据字段值的文件后缀自动检测模态类型:
当未显式指定 `modality` 且 `format` 是 `url` 时,系统会根据字段值的文件后缀自动检测模态类型:

> **重要:** 使用多模态字段(图片或视频)时,请确保您的模型提供商支持多模态 embedding。图片和视频字段必须包含有效的 URL 或二进制数据。目前,`DOUBAO` 提供商支持多模态数据处理。

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,48 @@ transform {
}

product_name_vector = product_name

multi_field_text_vector = [product_name, description]

multi_field_image_vector = [
{
field = product_image_url
modality = jpeg
format = url
},
{
field = thumbnail_image
modality = png
format = url
}
]

multi_field_video_vector = [
{
field = product_video_url
modality = mp4
format = url
},
{
field = promotional_video
modality = mov
format = url
}
]

multi_field_mix_vector = [
product_name,
{
field = product_image_url
modality = jpeg
format = url
},
{
field = product_video_url
modality = mp4
format = url
}
]
}

plugin_output = "multimodal_embedding_output"
Expand Down Expand Up @@ -219,6 +261,42 @@ sink {
}
]
},
{
field_name = multi_field_text_vector
field_type = float_vector
field_value = [
{
rule_type = NOT_NULL
}
]
},
{
field_name = multi_field_image_vector
field_type = float_vector
field_value = [
{
rule_type = NOT_NULL
}
]
},
{
field_name = multi_field_video_vector
field_type = float_vector
field_value = [
{
rule_type = NOT_NULL
}
]
},
{
field_name = multi_field_mix_vector
field_type = float_vector
field_value = [
{
rule_type = NOT_NULL
}
]
},
{
field_name = category
field_type = string
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,22 @@
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

@Slf4j
public class EmbeddingTransform extends MultipleFieldOutputTransform {

private final ReadonlyConfig config;
private List<Integer> fieldOriginalIndexes;
private transient Model model;
private Integer dimension;
private boolean isMultimodalFields = false;
private Map<Integer, FieldSpec> fieldSpecMap;
private Map<VectorFieldSpec, List<Integer>> fieldSpecMap;
private List<String> fieldNames;

private final Map<String, TreeMap<Long, byte[]>> binaryFileCache = new ConcurrentHashMap<>();
Expand Down Expand Up @@ -204,30 +204,35 @@ public void open() {
}

private void initOutputFields(SeaTunnelRowType inputRowType, ReadonlyConfig config) {
Map<Integer, FieldSpec> fieldSpecMap = new HashMap<>();
List<String> fieldNames = new ArrayList<>();
Map<String, Object> fieldsConfig =
config.get(EmbeddingTransformConfig.VECTORIZATION_FIELDS);
if (fieldsConfig == null || fieldsConfig.isEmpty()) {
throw new IllegalArgumentException("vectorization_fields configuration is required");
}

for (Map.Entry<String, Object> field : fieldsConfig.entrySet()) {
FieldSpec fieldSpec = new FieldSpec(field);
log.info("Field spec: {}", fieldSpec.toString());
String srcField = fieldSpec.getFieldName();
int srcFieldIndex;
try {
srcFieldIndex = inputRowType.indexOf(srcField);
} catch (IllegalArgumentException e) {
throw TransformCommonError.cannotFindInputFieldError(getPluginName(), srcField);
}
if (fieldSpec.isMultimodalField()) {
isMultimodalFields = true;
List<String> fieldNames = new ArrayList<>();
Map<VectorFieldSpec, List<Integer>> fieldSpecMap = new LinkedHashMap<>();
for (Map.Entry<String, Object> fieldConfig : fieldsConfig.entrySet()) {
VectorFieldSpec vectorFieldSpec = new VectorFieldSpec(fieldConfig);
log.info("Vector field spec: {}", vectorFieldSpec);
List<String> srcFieldNames =
vectorFieldSpec.getSrcFieldSpecs().stream()
.map(SrcFieldSpec::getFieldName)
.collect(Collectors.toList());
List<Integer> srcFieldIndexes = new ArrayList<>();
for (String srcFieldName : srcFieldNames) {
try {
srcFieldIndexes.add(inputRowType.indexOf(srcFieldName));
} catch (IllegalArgumentException e) {
throw TransformCommonError.cannotFindInputFieldsError(
getPluginName(), srcFieldNames);
}
}
fieldSpecMap.put(srcFieldIndex, fieldSpec);
fieldNames.add(field.getKey());
fieldSpecMap.put(vectorFieldSpec, srcFieldIndexes);
fieldNames.add(vectorFieldSpec.getFieldName());
Comment on lines +215 to +232

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isMultimodalFields = vectorFieldSpec.isMultimodalField();

There is a logical issue; currently, only the last value can be obtained.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your time, let me fix it

}
this.isMultimodalFields =
fieldSpecMap.keySet().stream().anyMatch(VectorFieldSpec::isMultimodalField);
this.fieldSpecMap = fieldSpecMap;
this.fieldNames = fieldNames;
}
Expand All @@ -239,19 +244,28 @@ protected Object[] getOutputFieldValues(SeaTunnelRowAccessor inputRow) {
if (MetadataUtil.isBinaryFormat(inputRow)) {
return vectorizationBinaryRow(inputRow);
}
Set<Integer> fieldOriginalIndexes = fieldSpecMap.keySet();
Object[] fieldValues = new Object[fieldOriginalIndexes.size()];
List<ByteBuffer> vectorization;

Set<VectorFieldSpec> vectorFieldSpecs = fieldSpecMap.keySet();
Object[] fieldValues = new Object[vectorFieldSpecs.size()];
int i = 0;

for (Integer fieldOriginalIndex : fieldOriginalIndexes) {
FieldSpec fieldSpec = fieldSpecMap.get(fieldOriginalIndex);
Object value = inputRow.getField(fieldOriginalIndex);
for (VectorFieldSpec vectorFieldSpec : vectorFieldSpecs) {
List<SrcFieldSpec> srcFieldSpecs = vectorFieldSpec.getSrcFieldSpecs();
List<Integer> srcFieldIndexes = fieldSpecMap.get(vectorFieldSpec);
List<SrcField> srcFields = new ArrayList<>();
for (int j = 0; j < srcFieldSpecs.size(); j++) {
srcFields.add(
new SrcField(
srcFieldSpecs.get(j),
inputRow.getField(srcFieldIndexes.get(j))));
}
fieldValues[i++] =
isMultimodalFields ? new MultimodalFieldValue(fieldSpec, value) : value;
isMultimodalFields
? new MultimodalFieldValue(srcFields)
: srcFields.get(0).getFieldValue();
}

vectorization = model.vectorization(fieldValues);
List<ByteBuffer> vectorization = model.vectorization(fieldValues);
return vectorization.toArray();
} catch (Exception e) {
throw new RuntimeException("Failed to data vectorization", e);
Expand Down Expand Up @@ -289,32 +303,34 @@ public boolean isMultimodalFields() {

/** Process a row in binary format: [data, relativePath, partIndex] */
private Object[] vectorizationBinaryRow(SeaTunnelRowAccessor inputRow) throws Exception {

byte[] completeData = processBinaryRow(inputRow);
if (completeData == null) {
return null;
}
Set<Integer> fieldOriginalIndexes = fieldSpecMap.keySet();
Object[] fieldValues = new Object[fieldOriginalIndexes.size()];

Set<VectorFieldSpec> vectorFieldSpecs = fieldSpecMap.keySet();
Object[] fieldValues = new Object[vectorFieldSpecs.size()];
int i = 0;

for (Integer fieldOriginalIndex : fieldOriginalIndexes) {
FieldSpec fieldSpec = fieldSpecMap.get(fieldOriginalIndex);
if (fieldSpec.isBinary()) {
fieldValues[i++] = new MultimodalFieldValue(fieldSpec, completeData);
} else {
log.warn(
"Non-binary field {} configured in binary format data",
fieldSpec.getFieldName());
fieldValues[i++] = null;
for (VectorFieldSpec vectorFieldSpec : vectorFieldSpecs) {
List<SrcFieldSpec> srcFieldSpecs = vectorFieldSpec.getSrcFieldSpecs();
List<SrcField> srcFields = new ArrayList<>();
for (SrcFieldSpec srcFieldSpec : srcFieldSpecs) {
if (srcFieldSpec.isBinary()) {
srcFields.add(new SrcField(srcFieldSpec, completeData));
} else {
log.warn(
"Non-binary field {} configured in binary format data",
srcFieldSpec.getFieldName());
}
}
fieldValues[i++] = srcFields.isEmpty() ? null : new MultimodalFieldValue(srcFields);
}

try {
return model.vectorization(fieldValues).toArray();
} catch (Exception e) {
throw new RuntimeException(
"Failed to vectorize binary data for file: " + inputRow.toString(), e);
throw new RuntimeException("Failed to vectorize binary data for file: " + inputRow, e);
}
}

Expand Down
Loading
Loading