Bug description
In VertexAIEmbeddingUtils, the task type that selects the embedding action is sent under the property name taskType. The task type is supposed to default to RETRIEVAL_DOCUMENT, but because the property name is wrong, Vertex AI does not pick it up and ends up using RETRIEVAL_QUERY instead. An explicitly configured task type has no effect either. According to the Vertex AI documentation, the expected property name is task_type.
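To make the failure mode concrete, here is a toy model of the behavior described above. It is not Vertex AI's real implementation. It assumes, as the observations in this report suggest, that the service reads only the task_type key of each instance and uses RETRIEVAL_QUERY when that key is absent, so a taskType entry is silently ignored. The class and method names are hypothetical.

```java
import java.util.Map;

/**
 * Toy model (hypothetical, not Vertex AI's code) of the behavior reported here:
 * the task type is read from "task_type" and defaults to RETRIEVAL_QUERY.
 */
public class TaskTypeFallbackModel {

    static final String EXPECTED_KEY = "task_type";       // key name per the Vertex AI docs
    static final String DEFAULT_TASK = "RETRIEVAL_QUERY"; // default observed in this report

    /** Returns the task type the modeled service would apply to one instance. */
    static String effectiveTaskType(Map<String, String> instance) {
        // Any other key, such as "taskType", is never looked up.
        return instance.getOrDefault(EXPECTED_KEY, DEFAULT_TASK);
    }

    public static void main(String[] args) {
        var sentBySpringAi = Map.of("content", "hello", "taskType", "RETRIEVAL_DOCUMENT");
        var expectedByVertex = Map.of("content", "hello", "task_type", "RETRIEVAL_DOCUMENT");
        System.out.println(effectiveTaskType(sentBySpringAi));   // prints RETRIEVAL_QUERY
        System.out.println(effectiveTaskType(expectedByVertex)); // prints RETRIEVAL_DOCUMENT
    }
}
```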
Environment
Java 21
Spring AI version: 1.0.0-M5
No vector store used.
Steps to reproduce
1) Generate an embedding using EmbeddingModel with the default options (which, according to this, should use RETRIEVAL_DOCUMENT). Capture the embedding.
- Use the Google native SDK to generate an embedding with RETRIEVAL_DOCUMENT. Confirm that the embeddings do not match: they show only about 85 to 93% dot-product similarity.
- Use the Python native SDK to generate an embedding with RETRIEVAL_QUERY. Confirm that the embeddings match.
- Use the Python native SDK to generate an embedding without specifying a task type (which defaults to RETRIEVAL_QUERY). Confirm that the embeddings match.
2) Generate an embedding using EmbeddingModel explicitly configured for RETRIEVAL_DOCUMENT via spring.ai.vertex.ai.embedding.text.options.task-type. Capture the embedding.
- Use the Google native SDK to generate an embedding with RETRIEVAL_DOCUMENT. Confirm that the embeddings do not match.
- Use the Google native SDK to generate an embedding with RETRIEVAL_QUERY. Confirm that the embeddings match.
- Use the Google native SDK to generate an embedding without specifying a task type (which defaults to RETRIEVAL_QUERY). Confirm that the embeddings match.
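The comparisons in these steps can be quantified with a small helper like the following sketch (class and method names are hypothetical). A dot product equals cosine similarity only when both vectors have unit length, so the sketch also computes an explicitly normalized cosine similarity instead of assuming the returned embeddings are normalized.

```java
/**
 * Hypothetical helpers for comparing two embedding vectors.
 * dot() is the raw dot product; cosine() divides by both norms.
 */
public class EmbeddingSimilarity {

    static double dot(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("dimension mismatch: " + a.length + " vs " + b.length);
        }
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += (double) a[i] * b[i]; // accumulate in double to limit rounding error
        }
        return sum;
    }

    static double cosine(float[] a, float[] b) {
        double norms = Math.sqrt(dot(a, a)) * Math.sqrt(dot(b, b));
        if (norms == 0.0) {
            throw new IllegalArgumentException("zero-length vector");
        }
        return dot(a, b) / norms;
    }

    public static void main(String[] args) {
        float[] x = {0.6f, 0.8f};
        float[] y = {0.8f, 0.6f};
        System.out.printf("dot=%.2f cosine=%.2f%n", dot(x, y), cosine(x, y)); // prints dot=0.96 cosine=0.96
    }
}
```

With real embeddings, x and y would be the arrays returned by Spring AI and by the native SDK for the same text.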
Using a forked version of VertexAIEmbeddingUtils that sends the property as task_type, I get the expected behavior.
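A minimal sketch of the kind of change the fork makes, assuming the instance is assembled as a key/value structure. This is not the actual VertexAIEmbeddingUtils code; buildInstance and TASK_TYPE_KEY are hypothetical names, and a plain Map stands in for the protobuf Struct so the example runs without the Vertex AI SDK.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Hypothetical sketch (not Spring AI's code): builds one prediction instance
 * using the "task_type" key that the Vertex AI docs expect.
 */
public class EmbeddingInstanceSketch {

    // The buggy code used "taskType"; the documented key is "task_type".
    static final String TASK_TYPE_KEY = "task_type";

    static Map<String, String> buildInstance(String content, String taskType) {
        Map<String, String> instance = new LinkedHashMap<>(); // keeps insertion order for printing
        instance.put("content", content);
        if (taskType != null) {
            instance.put(TASK_TYPE_KEY, taskType);
        }
        return instance;
    }

    public static void main(String[] args) {
        System.out.println(buildInstance("hello world", "RETRIEVAL_DOCUMENT"));
        // prints {content=hello world, task_type=RETRIEVAL_DOCUMENT}
    }
}
```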
Expected behavior
Using the Spring AI abstraction should return an embedding that matches the requested task_type.
Minimal Complete Reproducible example
```java
package com.example;

import static java.util.stream.Collectors.toList;

import java.util.Arrays;
import java.util.List;

import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.vertexai.embedding.text.VertexAiTextEmbeddingOptions;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;

import com.google.cloud.aiplatform.v1.EndpointName;
import com.google.cloud.aiplatform.v1.PredictRequest;
import com.google.cloud.aiplatform.v1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1.PredictionServiceSettings;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;

import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;

@RequiredArgsConstructor
@Component
public class BugTester {

    private static final String MODEL_ID = "text-embedding-005";
    private static final String LOCATION = "us-central1";

    private final EmbeddingModel embeddingModel;

    // Embedding via the Spring AI abstraction, with the task type set explicitly.
    float[] getEmbeddingUsingSpringAI(String text, VertexAiTextEmbeddingOptions.TaskType taskType) {
        var options = VertexAiTextEmbeddingOptions.builder().model(MODEL_ID).taskType(taskType).build();
        var request = new EmbeddingRequest(List.of(text), options);
        var response = embeddingModel.call(request);
        return response.getResults().get(0).getOutput();
    }

    // Embedding via the Google Vertex AI Java SDK, sending the task type as "task_type".
    @SneakyThrows
    float[] getEmbeddingUsingGoogleSdk(String text, VertexAiTextEmbeddingOptions.TaskType taskType) {
        var endpoint = "%s-aiplatform.googleapis.com:443".formatted(LOCATION);
        var project = "YOUR-GCP-PROJECT-ID";
        var settings = PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build();
        var endpointName = EndpointName.ofProjectLocationPublisherModelName(project, LOCATION, "google", MODEL_ID);

        try (PredictionServiceClient client = PredictionServiceClient.create(settings)) {
            PredictRequest.Builder request = PredictRequest.newBuilder().setEndpoint(endpointName.toString());
            // One instance: the text to embed plus the task type under the documented key.
            request.addInstances(valueOf(Struct.newBuilder()
                    .putFields("content", valueOf(text))
                    .putFields("task_type", valueOf(taskType.name()))
                    .build()));

            var prediction = client.predict(request.build()).getPredictionsList().getFirst();
            Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings");
            Value values = embeddings.getStructValue().getFieldsOrThrow("values");

            // Protobuf numbers are doubles; narrow them to float to match Spring AI's output type.
            List<Float> floats = values.getListValue().getValuesList().stream()
                    .map(Value::getNumberValue)
                    .map(Double::floatValue)
                    .collect(toList());
            var floatArray = new float[floats.size()];
            for (int i = 0; i < floats.size(); i++) {
                floatArray[i] = floats.get(i);
            }
            return floatArray;
        }
    }

    // Fails if the Spring AI embedding differs from the native SDK embedding for the same task type.
    public void compare(String text) {
        var springAi = getEmbeddingUsingSpringAI(text, VertexAiTextEmbeddingOptions.TaskType.RETRIEVAL_DOCUMENT);
        var nativeSdk = getEmbeddingUsingGoogleSdk(text, VertexAiTextEmbeddingOptions.TaskType.RETRIEVAL_DOCUMENT);
        Assert.isTrue(Arrays.equals(springAi, nativeSdk), "Spring AI != Native");
    }

    private static Value valueOf(String text) {
        return Value.newBuilder().setStringValue(text).build();
    }

    private static Value valueOf(Struct struct) {
        return Value.newBuilder().setStructValue(struct).build();
    }
}
```
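compare() relies on Arrays.equals, which requires bit-for-bit identical floats. That held for the matching cases above; if it ever proves too strict, a tolerance-based check like this hypothetical helper could replace it. The tolerance value is an arbitrary assumption, not something taken from Spring AI or Vertex AI.

```java
/**
 * Hypothetical alternative to Arrays.equals for compare(): two embeddings
 * match when every component differs by at most the given tolerance.
 */
public class EmbeddingMatch {

    static boolean approximatelyEqual(float[] a, float[] b, float tolerance) {
        if (a.length != b.length) {
            return false; // different dimensions can never match
        }
        for (int i = 0; i < a.length; i++) {
            if (Math.abs(a[i] - b[i]) > tolerance) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        float[] springAi = {0.10f, 0.20f, 0.30f};
        float[] nativeSdk = {0.10f, 0.20f, 0.30001f};
        System.out.println(approximatelyEqual(springAi, nativeSdk, 1e-4f)); // prints true
    }
}
```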