From 2c2278ce852cfcef5a5bcc6ccfe1225bd105bd76 Mon Sep 17 00:00:00 2001 From: Alexandr Gorshenin Date: Wed, 18 Feb 2026 11:50:35 +0000 Subject: [PATCH 1/8] Switch BulkUpsertData from class to interface --- pom.xml | 15 ++++++++++ table/pom.xml | 12 +++++++- .../src/main/java/tech/ydb/table/Session.java | 2 +- .../tech/ydb/table/query/BulkUpsertData.java | 19 ++++-------- .../ydb/table/query/BulkUpsertProtoData.java | 29 +++++++++++++++++++ .../ydb/table/integration/ReadTableTest.java | 2 +- .../ydb/table/integration/ValuesReadTest.java | 1 - 7 files changed, 63 insertions(+), 17 deletions(-) create mode 100644 table/src/main/java/tech/ydb/table/query/BulkUpsertProtoData.java diff --git a/pom.xml b/pom.xml index 410e9024f..c7589ba9b 100644 --- a/pom.xml +++ b/pom.xml @@ -41,6 +41,7 @@ 2.8.9 5.11.0 1.19.3 + 18.3.0 @@ -87,6 +88,17 @@ ${gson.version} + + org.apache.arrow + arrow-vector + ${apache.arrow.version} + + + org.apache.arrow + arrow-memory-netty + ${apache.arrow.version} + + junit @@ -242,6 +254,9 @@ argLine + + + diff --git a/table/pom.xml b/table/pom.xml index 358f7dcf7..6cbb69458 100644 --- a/table/pom.xml +++ b/table/pom.xml @@ -29,12 +29,22 @@ ydb-sdk-common + + org.apache.arrow + arrow-vector + true + + + org.apache.arrow + arrow-memory-netty + true + + tech.ydb.test ydb-junit4-support test - org.apache.logging.log4j log4j-slf4j-impl diff --git a/table/src/main/java/tech/ydb/table/Session.java b/table/src/main/java/tech/ydb/table/Session.java index 7793c972e..d4da4e7a1 100644 --- a/table/src/main/java/tech/ydb/table/Session.java +++ b/table/src/main/java/tech/ydb/table/Session.java @@ -199,7 +199,7 @@ default CompletableFuture executeScanQuery(String query, Params params, CompletableFuture> keepAlive(KeepAliveSessionSettings settings); default CompletableFuture executeBulkUpsert(String tablePath, ListValue rows, BulkUpsertSettings settings) { - return executeBulkUpsert(tablePath, new BulkUpsertData(rows), settings); + return executeBulkUpsert(tablePath, 
BulkUpsertData.fromRows(rows), settings); } CompletableFuture executeBulkUpsert(String tablePath, BulkUpsertData data, BulkUpsertSettings settings); diff --git a/table/src/main/java/tech/ydb/table/query/BulkUpsertData.java b/table/src/main/java/tech/ydb/table/query/BulkUpsertData.java index 73d170ffc..96dee3b1d 100644 --- a/table/src/main/java/tech/ydb/table/query/BulkUpsertData.java +++ b/table/src/main/java/tech/ydb/table/query/BulkUpsertData.java @@ -9,21 +9,14 @@ * * @author Aleksandr Gorshenin */ -public class BulkUpsertData { - private final ValueProtos.TypedValue rows; +public interface BulkUpsertData { + void applyToRequest(YdbTable.BulkUpsertRequest.Builder builder); - public BulkUpsertData(ListValue rows) { - this.rows = ValueProtos.TypedValue.newBuilder() - .setType(rows.getType().toPb()) - .setValue(rows.toPb()) - .build(); + static BulkUpsertData fromRows(ListValue list) { + return new BulkUpsertProtoData(list); } - public BulkUpsertData(ValueProtos.TypedValue rows) { - this.rows = rows; - } - - public void applyToRequest(YdbTable.BulkUpsertRequest.Builder builder) { - builder.setRows(rows); + static BulkUpsertData fromProto(ValueProtos.TypedValue rows) { + return new BulkUpsertProtoData(rows); } } diff --git a/table/src/main/java/tech/ydb/table/query/BulkUpsertProtoData.java b/table/src/main/java/tech/ydb/table/query/BulkUpsertProtoData.java new file mode 100644 index 000000000..17d575f46 --- /dev/null +++ b/table/src/main/java/tech/ydb/table/query/BulkUpsertProtoData.java @@ -0,0 +1,29 @@ +package tech.ydb.table.query; + +import tech.ydb.proto.ValueProtos; +import tech.ydb.proto.table.YdbTable; +import tech.ydb.table.values.ListValue; + +/** + * + * @author Aleksandr Gorshenin + */ +public class BulkUpsertProtoData implements BulkUpsertData { + private final ValueProtos.TypedValue rows; + + public BulkUpsertProtoData(ListValue rows) { + this.rows = ValueProtos.TypedValue.newBuilder() + .setType(rows.getType().toPb()) + .setValue(rows.toPb()) + 
.build(); + } + + public BulkUpsertProtoData(ValueProtos.TypedValue rows) { + this.rows = rows; + } + + @Override + public void applyToRequest(YdbTable.BulkUpsertRequest.Builder builder) { + builder.setRows(rows); + } +} diff --git a/table/src/test/java/tech/ydb/table/integration/ReadTableTest.java b/table/src/test/java/tech/ydb/table/integration/ReadTableTest.java index 1d736c3af..d86996b49 100644 --- a/table/src/test/java/tech/ydb/table/integration/ReadTableTest.java +++ b/table/src/test/java/tech/ydb/table/integration/ReadTableTest.java @@ -99,7 +99,7 @@ public static void prepareTable() { }).collect(Collectors.toList()); retryCtx.supplyStatus(session -> session.executeBulkUpsert(tablePath, - new BulkUpsertData(ProtoValue.toTypedValue(ListType.of(batchType).newValue(batchData))) + BulkUpsertData.fromProto(ProtoValue.toTypedValue(ListType.of(batchType).newValue(batchData))) )).join().expectSuccess("bulk upsert problem in table " + tablePath); } diff --git a/table/src/test/java/tech/ydb/table/integration/ValuesReadTest.java b/table/src/test/java/tech/ydb/table/integration/ValuesReadTest.java index ba91f67f4..89d1faf07 100644 --- a/table/src/test/java/tech/ydb/table/integration/ValuesReadTest.java +++ b/table/src/test/java/tech/ydb/table/integration/ValuesReadTest.java @@ -468,7 +468,6 @@ public void date32datetime64timestamp64interval64() { @Test public void timestamp64ReadTest() { - System.out.println(Instant.ofEpochSecond(-4611669897600L)); DataQueryResult result = CTX.supplyResult( s -> s.executeDataQuery("SELECT " + "Timestamp64('-144169-01-01T00:00:00Z') as t1," From b53e4836057bda32657e0fd533262f186d6748cc Mon Sep 17 00:00:00 2001 From: Alexandr Gorshenin Date: Wed, 18 Feb 2026 12:01:45 +0000 Subject: [PATCH 2/8] Added implementations for CSV and ApacheArrow formats --- .../ydb/table/query/BulkUpsertArrowData.java | 35 +++++++++++++++++++ .../ydb/table/query/BulkUpsertCsvData.java | 23 ++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 
table/src/main/java/tech/ydb/table/query/BulkUpsertArrowData.java create mode 100644 table/src/main/java/tech/ydb/table/query/BulkUpsertCsvData.java diff --git a/table/src/main/java/tech/ydb/table/query/BulkUpsertArrowData.java b/table/src/main/java/tech/ydb/table/query/BulkUpsertArrowData.java new file mode 100644 index 000000000..1b2d23e8b --- /dev/null +++ b/table/src/main/java/tech/ydb/table/query/BulkUpsertArrowData.java @@ -0,0 +1,35 @@ +package tech.ydb.table.query; + +import com.google.protobuf.ByteString; + +import tech.ydb.proto.formats.YdbFormats; +import tech.ydb.proto.table.YdbTable; + +/** + * + * @author Aleksandr Gorshenin + */ +public class BulkUpsertArrowData implements BulkUpsertData { + private final ByteString schema; + private final ByteString data; + + public BulkUpsertArrowData(ByteString schema, ByteString data) { + this.schema = schema; + this.data = data; + } + + public ByteString getSchema() { + return schema; + } + + public ByteString getData() { + return data; + } + + @Override + public void applyToRequest(YdbTable.BulkUpsertRequest.Builder builder) { + builder.setArrowBatchSettings( + YdbFormats.ArrowBatchSettings.newBuilder().setSchema(schema).build() + ).setData(data); + } +} diff --git a/table/src/main/java/tech/ydb/table/query/BulkUpsertCsvData.java b/table/src/main/java/tech/ydb/table/query/BulkUpsertCsvData.java new file mode 100644 index 000000000..bb9c327fd --- /dev/null +++ b/table/src/main/java/tech/ydb/table/query/BulkUpsertCsvData.java @@ -0,0 +1,23 @@ +package tech.ydb.table.query; + +import com.google.protobuf.ByteString; + +import tech.ydb.proto.formats.YdbFormats; +import tech.ydb.proto.table.YdbTable; + +/** + * + * @author Aleksandr Gorshenin + */ +public class BulkUpsertCsvData implements BulkUpsertData { + private final ByteString data; + + public BulkUpsertCsvData(ByteString data) { + this.data = data; + } + + @Override + public void applyToRequest(YdbTable.BulkUpsertRequest.Builder builder) { + 
builder.setCsvSettings(YdbFormats.CsvSettings.newBuilder().build()).setData(data); + } +} From fb98b8744c01e16fbe32855b02bf4d12937890e9 Mon Sep 17 00:00:00 2001 From: Alexandr Gorshenin Date: Wed, 18 Feb 2026 12:04:46 +0000 Subject: [PATCH 3/8] Added ApacheArrowWriter helper to build arrow batches --- .../ydb/table/query/ApacheArrowWriter.java | 871 ++++++++++++++++++ .../tech/ydb/table/utils/LittleEndian.java | 3 + .../table/query/ApacheArrowWriterTest.java | 348 +++++++ 3 files changed, 1222 insertions(+) create mode 100644 table/src/main/java/tech/ydb/table/query/ApacheArrowWriter.java create mode 100644 table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java diff --git a/table/src/main/java/tech/ydb/table/query/ApacheArrowWriter.java b/table/src/main/java/tech/ydb/table/query/ApacheArrowWriter.java new file mode 100644 index 000000000..072541917 --- /dev/null +++ b/table/src/main/java/tech/ydb/table/query/ApacheArrowWriter.java @@ -0,0 +1,871 @@ +package tech.ydb.table.query; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.time.Duration; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.ZoneOffset; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.function.Function; + +import com.google.protobuf.ByteString; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.BaseFixedWidthVector; +import org.apache.arrow.vector.BaseVariableWidthVector; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.FixedSizeBinaryVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; +import 
org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; + +import tech.ydb.table.utils.LittleEndian; +import tech.ydb.table.values.DecimalValue; +import tech.ydb.table.values.PrimitiveType; +import tech.ydb.table.values.Type; +import tech.ydb.table.values.proto.ProtoValue; + +/** + * + * @author Aleksandr Gorshenin + */ +public class ApacheArrowWriter implements AutoCloseable { + public interface Batch { + Row writeNextRow(); + BulkUpsertArrowData buildBatch() throws IOException; + } + + public interface Row { + void writeNull(String column); + + void writeBool(String column, boolean value); + + void writeInt8(String column, byte value); + void writeInt16(String column, short value); + void writeInt32(String column, int value); + void writeInt64(String column, long value); + + void writeUint8(String column, int value); + void writeUint16(String column, int value); + void writeUint32(String column, long value); + void writeUint64(String column, long value); + + void writeFloat(String column, float value); + void writeDouble(String column, double value); + + void writeText(String column, String text); + void writeJson(String column, String json); + void writeJsonDocument(String column, String jsonDocument); + + void writeBytes(String column, byte[] bytes); + void writeYson(String column, byte[] yson); + + void writeUuid(String column, UUID uuid); + + void writeDate(String column, LocalDate date); + void writeDatetime(String column, LocalDateTime datetime); + 
void writeTimestamp(String column, Instant instant); + void writeInterval(String column, Duration interval); + + void writeDate32(String column, LocalDate date32); + void writeDatetime64(String column, LocalDateTime datetime64); + void writeTimestamp64(String column, Instant instant64); + void writeInterval64(String column, Duration interval64); + + void writeDecimal(String column, DecimalValue value); + } + + private final VectorSchemaRoot vsr; + private final Map> columns = new HashMap<>(); + + private ApacheArrowWriter(BufferAllocator allocator, List columnsList) { + FieldVector[] vectors = new FieldVector[columnsList.size()]; + for (int idx = 0; idx < columnsList.size(); idx += 1) { + ColumnInfo column = columnsList.get(idx); + Column vector = column.createVector(allocator); + vectors[idx] = vector.vector; + columns.put(column.name, vector); + } + + this.vsr = VectorSchemaRoot.of(vectors); + } + + public Batch createNewBatch(int estimatedRowsCount) { + // reset all + for (Column column: columns.values()) { + column.allocateNew(estimatedRowsCount); + } + return new BatchImpl(); + } + + @Override + public void close() { + vsr.close(); + + for (FieldVector field: vsr.getFieldVectors()) { + field.close(); + } + } + + private class BatchImpl implements Batch { + private int rowIndex = 0; + + @Override + public Row writeNextRow() { + return new RowImpl(rowIndex++); + } + + @Override + public BulkUpsertArrowData buildBatch() throws IOException { + vsr.setRowCount(rowIndex); + return new BulkUpsertArrowData(serializeSchema(), serializeBatch()); + } + + private ByteString serializeSchema() throws IOException { + try (ByteString.Output out = ByteString.newOutput()) { + try (WriteChannel channel = new WriteChannel(Channels.newChannel(out))) { + MessageSerializer.serialize(channel, vsr.getSchema()); + return out.toByteString(); + } + } + } + + private ByteString serializeBatch() throws IOException { + try (ByteString.Output out = ByteString.newOutput()) { + try 
(WriteChannel channel = new WriteChannel(Channels.newChannel(out))) { + VectorUnloader loader = new VectorUnloader(vsr); + try (ArrowRecordBatch batch = loader.getRecordBatch()) { + MessageSerializer.serialize(channel, batch); + return out.toByteString(); + } + } + } + } + } + + private class RowImpl implements Row { + private final int rowIndex; + + RowImpl(int rowIndex) { + this.rowIndex = rowIndex; + } + + private Column find(String column) { + Column vector = columns.get(column); + if (vector == null) { + throw new IllegalArgumentException("Column '" + column + "' not found"); + } + return vector; + } + + @Override + public void writeNull(String column) { + find(column).writeNull(rowIndex); + } + + @Override + public void writeBool(String column, boolean value) { + find(column).writeBool(rowIndex, value); + } + + @Override + public void writeInt8(String column, byte value) { + find(column).writeInt8(rowIndex, value); + } + + @Override + public void writeInt16(String column, short value) { + find(column).writeInt16(rowIndex, value); + } + + @Override + public void writeInt32(String column, int value) { + find(column).writeInt32(rowIndex, value); + } + + @Override + public void writeInt64(String column, long value) { + find(column).writeInt64(rowIndex, value); + } + + @Override + public void writeUint8(String column, int value) { + find(column).writeUint8(rowIndex, value); + } + + @Override + public void writeUint16(String column, int value) { + find(column).writeUint16(rowIndex, value); + } + + @Override + public void writeUint32(String column, long value) { + find(column).writeUint32(rowIndex, value); + } + + @Override + public void writeUint64(String column, long value) { + find(column).writeUint64(rowIndex, value); + } + + @Override + public void writeFloat(String column, float value) { + find(column).writeFloat(rowIndex, value); + } + + @Override + public void writeDouble(String column, double value) { + find(column).writeDouble(rowIndex, value); + } + + 
@Override + public void writeText(String column, String text) { + find(column).writeText(rowIndex, text); + } + + @Override + public void writeJson(String column, String json) { + find(column).writeJson(rowIndex, json); + } + + @Override + public void writeJsonDocument(String column, String jsonDocument) { + find(column).writeJsonDocument(rowIndex, jsonDocument); + } + + @Override + public void writeBytes(String column, byte[] bytes) { + find(column).writeBytes(rowIndex, bytes); + } + + @Override + public void writeYson(String column, byte[] yson) { + find(column).writeYson(rowIndex, yson); + } + + @Override + public void writeUuid(String column, UUID uuid) { + find(column).writeUuid(rowIndex, uuid); + } + + @Override + public void writeDate(String column, LocalDate date) { + find(column).writeDate(rowIndex, date); + } + + @Override + public void writeDatetime(String column, LocalDateTime datetime) { + find(column).writeDatetime(rowIndex, datetime); + } + + @Override + public void writeTimestamp(String column, Instant instant) { + find(column).writeTimestamp(rowIndex, instant); + } + + @Override + public void writeInterval(String column, Duration interval) { + find(column).writeInterval(rowIndex, interval); + } + + @Override + public void writeDate32(String column, LocalDate date32) { + find(column).writeDate32(rowIndex, date32); + } + + @Override + public void writeDatetime64(String column, LocalDateTime datetime64) { + find(column).writeDatetime64(rowIndex, datetime64); + } + + @Override + public void writeTimestamp64(String column, Instant instant64) { + find(column).writeTimestamp64(rowIndex, instant64); + } + + @Override + public void writeInterval64(String column, Duration interval64) { + find(column).writeInterval64(rowIndex, interval64); + } + + @Override + public void writeDecimal(String column, DecimalValue value) { + find(column).writeDecimal(rowIndex, value); + } + } + + private abstract static class Column { + protected final Field field; + protected 
final Type type; + protected final T vector; + + Column(Field field, Type type, T vector) { + this.field = field; + this.type = type; + this.vector = vector; + } + + protected IllegalStateException error(String method) { + return new IllegalStateException("cannot call " + method + ", actual type: " + type); + } + + public abstract void allocateNew(int estimated); + + void writeNull(int rowIndex) { + if (field.isNullable()) { + vector.setNull(rowIndex); + } else { + throw error("writeNull"); + } + } + + void writeBool(int rowIndex, boolean value) { + throw error("writeBool"); + } + + void writeInt8(int rowIndex, byte value) { + throw error("writeInt8"); + } + + void writeInt16(int rowIndex, short value) { + throw error("writeInt16"); + } + + void writeInt32(int rowIndex, int value) { + throw error("writeInt32"); + } + + void writeInt64(int rowIndex, long value) { + throw error("writeInt64"); + } + + void writeUint8(int rowIndex, int value) { + throw error("writeUint8"); + } + + void writeUint16(int rowIndex, int value) { + throw error("writeUint16"); + } + + void writeUint32(int rowIndex, long value) { + throw error("writeUint32"); + } + + void writeUint64(int rowIndex, long value) { + throw error("writeUint64"); + } + + void writeFloat(int rowIndex, float value) { + throw error("writeFloat"); + } + + void writeDouble(int rowIndex, double value) { + throw error("writeDouble"); + } + + void writeText(int rowIndex, String text) { + throw error("writeText"); + } + + void writeJson(int rowIndex, String json) { + throw error("writeJson"); + } + + void writeJsonDocument(int rowIndex, String jsonDocument) { + throw error("writeJsonDocument"); + } + + void writeBytes(int rowIndex, byte[] bytes) { + throw error("writeBytes"); + } + + void writeYson(int rowIndex, byte[] yson) { + throw error("writeYson"); + } + + void writeUuid(int rowIndex, UUID yson) { + throw error("writeUuid"); + } + + void writeDate(int rowIndex, LocalDate date) { + throw error("writeDate"); + } + + void 
writeDatetime(int rowIndex, LocalDateTime datetime) { + throw error("writeDatetime"); + } + + void writeTimestamp(int rowIndex, Instant instant) { + throw error("writeTimestamp"); + } + + void writeInterval(int rowIndex, Duration interval) { + throw error("writeInterval"); + } + + void writeDate32(int rowIndex, LocalDate date32) { + throw error("writeDate32"); + } + + void writeDatetime64(int rowIndex, LocalDateTime datetime64) { + throw error("writeDatetime64"); + } + + void writeTimestamp64(int rowIndex, Instant instant64) { + throw error("writeTimestamp64"); + } + + void writeInterval64(int rowIndex, Duration interval64) { + throw error("writeInterval64"); + } + + void writeDecimal(int rowIndex, DecimalValue value) { + throw error("writeDecimal"); + } + } + + private static class FixedWidthColumn extends Column { + + FixedWidthColumn(Field field, Type type, T vector) { + super(field, type, vector); + } + + @Override + public void allocateNew(int estimated) { + vector.allocateNew(estimated); + } + } + + private static class VariableWidthColumn extends Column { + + VariableWidthColumn(Field field, Type type, T vector) { + super(field, type, vector); + } + + @Override + public void allocateNew(int estimated) { + vector.allocateNew(estimated); + } + } + + private static class TinyIntColumn extends FixedWidthColumn { + + TinyIntColumn(Type type, BufferAllocator allocator, Field field) { + super(field, type, new TinyIntVector(field, allocator)); + } + + @Override + void writeBool(int rowIndex, boolean value) { + if (type != PrimitiveType.Bool) { + throw error("writeBool"); + } + vector.setSafe(rowIndex, value ? 
1 : 0); + } + + @Override + void writeInt8(int rowIndex, byte value) { + if (type != PrimitiveType.Int8) { + throw error("writeInt8"); + } + vector.setSafe(rowIndex, value); + } + + @Override + void writeUint8(int rowIndex, int value) { + if (type != PrimitiveType.Uint8) { + throw error("writeUint8"); + } + vector.setSafe(rowIndex, value); + } + } + + private static class SmallIntColumn extends FixedWidthColumn { + + SmallIntColumn(Type type, BufferAllocator allocator, Field field) { + super(field, type, new SmallIntVector(field, allocator)); + } + + @Override + void writeInt16(int rowIndex, short value) { + if (type != PrimitiveType.Int16) { + throw error("writeInt16"); + } + vector.setSafe(rowIndex, value); + } + + @Override + void writeUint16(int rowIndex, int value) { + if (type != PrimitiveType.Uint16) { + throw error("writeUint16"); + } + vector.setSafe(rowIndex, value); + } + + @Override + void writeDate(int rowIndex, LocalDate date) { + if (type != PrimitiveType.Date) { + throw error("writeDate"); + } + vector.setSafe(rowIndex, (int) date.toEpochDay()); + } + } + + private static class IntColumn extends FixedWidthColumn { + + IntColumn(Type type, BufferAllocator allocator, Field field) { + super(field, type, new IntVector(field, allocator)); + } + + @Override + void writeInt32(int rowIndex, int value) { + if (type != PrimitiveType.Int32) { + throw error("writeInt32"); + } + vector.setSafe(rowIndex, value); + } + + @Override + void writeUint32(int rowIndex, long value) { + if (type != PrimitiveType.Uint32) { + throw error("writeUint32"); + } + vector.setSafe(rowIndex, (int) value); + } + + @Override + void writeDate32(int rowIndex, LocalDate date) { + if (type != PrimitiveType.Date32) { + throw error("writeDate32"); + } + vector.setSafe(rowIndex, (int) date.toEpochDay()); + } + + @Override + void writeDatetime(int rowIndex, LocalDateTime datetime) { + if (type != PrimitiveType.Datetime) { + throw error("writeDatetime"); + } + vector.setSafe(rowIndex, (int) 
datetime.toEpochSecond(ZoneOffset.UTC)); + } + } + + private static class BigIntColumn extends FixedWidthColumn { + + BigIntColumn(Type type, BufferAllocator allocator, Field field) { + super(field, type, new BigIntVector(field, allocator)); + } + + @Override + void writeInt64(int rowIndex, long value) { + if (type != PrimitiveType.Int64) { + throw error("writeInt64"); + } + vector.setSafe(rowIndex, value); + } + + @Override + void writeUint64(int rowIndex, long value) { + if (type != PrimitiveType.Uint64) { + throw error("writeUint64"); + } + vector.setSafe(rowIndex, value); + } + + @Override + void writeDatetime64(int rowIndex, LocalDateTime datetime) { + if (type != PrimitiveType.Datetime64) { + throw error("writeDatetime64"); + } + vector.setSafe(rowIndex, datetime.toEpochSecond(ZoneOffset.UTC)); /* 64-bit: no int cast */ + } + + @Override + void writeTimestamp(int rowIndex, Instant value) { + if (type != PrimitiveType.Timestamp) { + throw error("writeTimestamp"); + } + long micros = value.getEpochSecond() * 1000000L + value.getNano() / 1000; + vector.setSafe(rowIndex, micros); + } + + @Override + void writeTimestamp64(int rowIndex, Instant value) { + if (type != PrimitiveType.Timestamp64) { + throw error("writeTimestamp64"); + } + long micros = value.getEpochSecond() * 1000000L + value.getNano() / 1000; + vector.setSafe(rowIndex, micros); + } + + @Override + void writeInterval(int rowIndex, Duration duration) { + if (type != PrimitiveType.Interval) { + throw error("writeInterval"); + } + long micros = duration.getSeconds() * 1000000L + duration.getNano() / 1000; /* Java 8 safe */ + vector.setSafe(rowIndex, micros); + } + + @Override + void writeInterval64(int rowIndex, Duration duration) { + if (type != PrimitiveType.Interval64) { + throw error("writeInterval64"); + } + long micros = duration.getSeconds() * 1000000L + duration.getNano() / 1000; /* Java 8 safe */ + vector.setSafe(rowIndex, micros); + } + } + + private static class FloatColumn extends FixedWidthColumn { + + FloatColumn(Type type, 
BufferAllocator allocator, Field field) { + super(field, type, new Float4Vector(field, allocator)); + } + + @Override + void writeFloat(int rowIndex, float value) { + vector.setSafe(rowIndex, value); + } + } + + private static class DoubleColumn extends FixedWidthColumn { + + DoubleColumn(Type type, BufferAllocator allocator, Field field) { + super(field, type, new Float8Vector(field, allocator)); + } + + @Override + void writeDouble(int rowIndex, double value) { + vector.setSafe(rowIndex, value); + } + } + + private static class VarCharColumn extends VariableWidthColumn { + + VarCharColumn(Type type, BufferAllocator allocator, Field field) { + super(field, type, new VarCharVector(field, allocator)); + } + + @Override + void writeText(int rowIndex, String text) { + if (type != PrimitiveType.Text) { + throw error("writeText"); + } + vector.setSafe(rowIndex, text.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + } + + @Override + void writeJson(int rowIndex, String json) { + if (type != PrimitiveType.Json) { + throw error("writeJson"); + } + vector.setSafe(rowIndex, json.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + } + + @Override + void writeJsonDocument(int rowIndex, String jsonDocument) { + if (type != PrimitiveType.JsonDocument) { + throw error("writeJsonDocument"); + } + vector.setSafe(rowIndex, jsonDocument.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + } + } + + private static class VarBinaryColumn extends VariableWidthColumn { + + VarBinaryColumn(Type type, BufferAllocator allocator, Field field) { + super(field, type, new VarBinaryVector(field, allocator)); + } + + @Override + void writeBytes(int rowIndex, byte[] bytes) { + if (type != PrimitiveType.Bytes) { + throw error("writeBytes"); + } + vector.setSafe(rowIndex, bytes); + } + + @Override + void writeYson(int rowIndex, byte[] yson) { + if (type != PrimitiveType.Yson) { + throw error("writeYson"); + } + vector.setSafe(rowIndex, yson); + } + } + + private static class FixedBinaryColumn extends FixedWidthColumn { + + FixedBinaryColumn(Type type, BufferAllocator allocator, Field field) { + super(field, 
type, new FixedSizeBinaryVector(field, allocator)); + } + + /** + * @see ProtoValue#newUuid(java.util.UUID) + */ + @Override + void writeUuid(int rowIndex, UUID uuid) { + if (type != PrimitiveType.Uuid) { + throw error("writeUuid"); + } + + ByteBuffer buf = ByteBuffer.allocate(16); + + long msb = uuid.getMostSignificantBits(); + long timeLow = (msb & 0xffffffff00000000L) >>> 32; + long timeMid = (msb & 0x00000000ffff0000L) << 16; + long timeHighAndVersion = (msb & 0x000000000000ffffL) << 48; + buf.putLong(LittleEndian.bswap(timeLow | timeMid | timeHighAndVersion)); + buf.putLong(uuid.getLeastSignificantBits()); + + vector.setSafe(rowIndex, buf.array()); + } + + @Override + void writeDecimal(int rowIndex, DecimalValue value) { + if (type.getKind() != Type.Kind.DECIMAL) { + throw error("writeDecimal"); + } + + ByteBuffer buf = ByteBuffer.allocate(16); + buf.putLong(LittleEndian.bswap(value.getLow())); + buf.putLong(LittleEndian.bswap(value.getHigh())); + + vector.setSafe(rowIndex, buf.array()); + } + } + + private static class ColumnInfo { + private final String name; + private final Type type; + + ColumnInfo(String name, Type type) { + this.name = name; + this.type = type; + } + + private Column createVector(BufferAllocator allocator) { + if (type.getKind() == Type.Kind.OPTIONAL) { + return createColumnVector(allocator, type.unwrapOptional(), t -> Field.nullable(name, t)); + } + + return createColumnVector(allocator, type, t -> Field.notNullable(name, t)); + } + } + + private static Column createColumnVector(BufferAllocator allocator, Type type, Function gen) { + if (type.getKind() == Type.Kind.DECIMAL) { + return new FixedBinaryColumn(type, allocator, gen.apply(new ArrowType.FixedSizeBinary(16))); + } + + if (type.getKind() == Type.Kind.PRIMITIVE) { + switch ((PrimitiveType) type) { + case Bool: + return new TinyIntColumn(type, allocator, gen.apply(new ArrowType.Int(8, false))); + + case Int8: + return new TinyIntColumn(type, allocator, gen.apply(new 
ArrowType.Int(8, true))); + case Int16: + return new SmallIntColumn(type, allocator, gen.apply(new ArrowType.Int(16, true))); + case Int32: + return new IntColumn(type, allocator, gen.apply(new ArrowType.Int(32, true))); + case Int64: + return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Int(64, true))); + + case Uint8: + return new TinyIntColumn(type, allocator, gen.apply(new ArrowType.Int(8, false))); + case Uint16: + return new SmallIntColumn(type, allocator, gen.apply(new ArrowType.Int(16, false))); + case Uint32: + return new IntColumn(type, allocator, gen.apply(new ArrowType.Int(32, false))); + case Uint64: + return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Int(64, false))); + + case Float: + return new FloatColumn(type, allocator, gen.apply( + new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + )); + case Double: + return new DoubleColumn(type, allocator, gen.apply( + new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + )); + + case Text: + case Json: + case JsonDocument: + return new VarCharColumn(type, allocator, gen.apply(new ArrowType.Utf8())); + + case Bytes: + case Yson: + return new VarBinaryColumn(type, allocator, gen.apply(new ArrowType.Binary())); + + case Uuid: + return new FixedBinaryColumn(type, allocator, gen.apply(new ArrowType.FixedSizeBinary(16))); + + case Date: + return new SmallIntColumn(type, allocator, gen.apply(new ArrowType.Int(16, false))); + case Datetime: + return new IntColumn(type, allocator, gen.apply(new ArrowType.Int(32, false))); + case Timestamp: + return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Int(64, true))); + case Interval: + return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Duration(TimeUnit.MILLISECOND))); + + case Date32: + return new IntColumn(type, allocator, gen.apply(new ArrowType.Int(32, true))); + case Datetime64: + return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Int(64, true))); + case Timestamp64: + return new 
BigIntColumn(type, allocator, gen.apply(new ArrowType.Int(64, true))); + case Interval64: + return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Int(64, true))); + + default: + break; + } + } + + throw new IllegalArgumentException("Type " + type + " is not supported in ArrowWriter"); + } + + public static Schema newSchema() { + return new Schema(); + } + + public static class Schema { + private final List columns = new ArrayList<>(); + + public Schema addColumn(String name, Type type) { + this.columns.add(new ColumnInfo(name, type)); + return this; + } + + public Schema addNullableColumn(String name, Type type) { + return addColumn(name, type.makeOptional()); + } + + public ApacheArrowWriter createWriter(BufferAllocator allocator) { + return new ApacheArrowWriter(allocator, columns); + } + } +} diff --git a/table/src/main/java/tech/ydb/table/utils/LittleEndian.java b/table/src/main/java/tech/ydb/table/utils/LittleEndian.java index b1c7fc56f..816d1ba1e 100644 --- a/table/src/main/java/tech/ydb/table/utils/LittleEndian.java +++ b/table/src/main/java/tech/ydb/table/utils/LittleEndian.java @@ -8,6 +8,9 @@ private LittleEndian() { } /** * Reverses the byte order of a long value. 
+ * + * @param v long value + * @return reversed long value */ public static long bswap(long v) { return diff --git a/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java b/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java new file mode 100644 index 000000000..13dce25fb --- /dev/null +++ b/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java @@ -0,0 +1,348 @@ +package tech.ydb.table.query; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.channels.Channels; +import java.time.Duration; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.util.UUID; + +import com.google.protobuf.ByteString; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.ipc.ReadChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.function.ThrowingRunnable; + +import tech.ydb.table.values.DecimalType; +import tech.ydb.table.values.DecimalValue; +import tech.ydb.table.values.ListType; +import tech.ydb.table.values.PrimitiveType; + +/** + * + * @author Aleksandr Gorshenin + */ +public class ApacheArrowWriterTest { + + private void assertIllegalArgument(String message, ThrowingRunnable runnable) { + IllegalArgumentException ex = Assert.assertThrows(IllegalArgumentException.class, runnable); + Assert.assertEquals(message, ex.getMessage()); + } + + private void assertIllegalState(String message, ThrowingRunnable runnable) { + IllegalStateException ex = Assert.assertThrows(IllegalStateException.class, runnable); + Assert.assertEquals(message, ex.getMessage()); + } + + private static RootAllocator allocator; + + @BeforeClass + public static void initAllocator() { + allocator = new RootAllocator(); + } + + @AfterClass + public static void cleanAllocator() { + 
allocator.close(); + } + + @Test + public void unsupportedTypesTest() { + // tz date is not supported + assertIllegalArgument("Type TzDate is not supported in ArrowWriter", () -> ApacheArrowWriter.newSchema() + .addColumn("col", PrimitiveType.TzDate) + .createWriter(allocator)); + + // complex types are not supported (except Optional) + assertIllegalArgument("Type Int32? is not supported in ArrowWriter", () -> ApacheArrowWriter.newSchema() + .addNullableColumn("col", PrimitiveType.Int32.makeOptional()) + .createWriter(allocator)); + + // Non primitive types are not supported (except Decimal) + assertIllegalArgument("Type List is not supported in ArrowWriter", () -> ApacheArrowWriter.newSchema() + .addColumn("col", ListType.of(PrimitiveType.Int32)) + .createWriter(allocator)); + } + + @Test + public void invalidColumnTest() { + try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema() + .addColumn("col", PrimitiveType.Uuid) + .createWriter(allocator)) { + + ApacheArrowWriter.Batch batch = writer.createNewBatch(0); + ApacheArrowWriter.Row row = batch.writeNextRow(); + assertIllegalArgument("Column 'col2' not found", () -> row.writeUuid("col2", UUID.randomUUID())); + } + } + + @Test + public void nullableTypeTest() { + try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema() + .addColumn("col1", PrimitiveType.Uuid) + .addColumn("col2", PrimitiveType.Uuid.makeOptional()) + .addNullableColumn("col3", PrimitiveType.Uuid) + .createWriter(allocator)) { + + ApacheArrowWriter.Batch batch = writer.createNewBatch(0); + ApacheArrowWriter.Row row = batch.writeNextRow(); + assertIllegalState("cannot call writeNull, actual type: Uuid", () -> row.writeNull("col1")); + row.writeNull("col2"); // success + row.writeNull("col3"); // success + } + } + + @Test + public void baseTypeValidationTest() throws IOException { + try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema() + .addColumn("col1", PrimitiveType.Uuid) + .addColumn("col2", PrimitiveType.Int32) + 
.createWriter(allocator)) { + + ApacheArrowWriter.Batch batch = writer.createNewBatch(0); + ApacheArrowWriter.Row row = batch.writeNextRow(); + + assertIllegalState("cannot call writeBool, actual type: Uuid", () -> row.writeBool("col1", true)); + + assertIllegalState("cannot call writeInt8, actual type: Uuid", () -> row.writeInt8("col1", (byte) 0)); + assertIllegalState("cannot call writeInt16, actual type: Uuid", () -> row.writeInt16("col1", (short) 0)); + assertIllegalState("cannot call writeInt32, actual type: Uuid", () -> row.writeInt32("col1", 0)); + assertIllegalState("cannot call writeInt64, actual type: Uuid", () -> row.writeInt64("col1", 0)); + + assertIllegalState("cannot call writeUint8, actual type: Uuid", () -> row.writeUint8("col1", 0)); + assertIllegalState("cannot call writeUint16, actual type: Uuid", () -> row.writeUint16("col1", 0)); + assertIllegalState("cannot call writeUint32, actual type: Uuid", () -> row.writeUint32("col1", 0)); + assertIllegalState("cannot call writeUint64, actual type: Uuid", () -> row.writeUint64("col1", 0)); + + assertIllegalState("cannot call writeFloat, actual type: Uuid", () -> row.writeFloat("col1", 0)); + assertIllegalState("cannot call writeDouble, actual type: Uuid", () -> row.writeDouble("col1", 0)); + + assertIllegalState("cannot call writeText, actual type: Uuid", () -> row.writeText("col1", "")); + assertIllegalState("cannot call writeJson, actual type: Uuid", () -> row.writeJson("col1", "")); + assertIllegalState("cannot call writeJsonDocument, actual type: Uuid", + () -> row.writeJsonDocument("col1", "")); + + assertIllegalState("cannot call writeBytes, actual type: Uuid", + () -> row.writeBytes("col1", new byte[0])); + assertIllegalState("cannot call writeYson, actual type: Uuid", + () -> row.writeYson("col1", new byte[0])); + + assertIllegalState("cannot call writeDate, actual type: Uuid", + () -> row.writeDate("col1", LocalDate.EPOCH)); + assertIllegalState("cannot call writeDatetime, actual type: Uuid", + 
() -> row.writeDatetime("col1", LocalDateTime.now())); + assertIllegalState("cannot call writeTimestamp, actual type: Uuid", + () -> row.writeTimestamp("col1", Instant.EPOCH)); + assertIllegalState("cannot call writeInterval, actual type: Uuid", + () -> row.writeInterval("col1", Duration.ZERO)); + + assertIllegalState("cannot call writeDate32, actual type: Uuid", + () -> row.writeDate32("col1", LocalDate.EPOCH)); + assertIllegalState("cannot call writeDatetime64, actual type: Uuid", + () -> row.writeDatetime64("col1", LocalDateTime.now())); + assertIllegalState("cannot call writeTimestamp64, actual type: Uuid", + () -> row.writeTimestamp64("col1", Instant.EPOCH)); + assertIllegalState("cannot call writeInterval64, actual type: Uuid", + () -> row.writeInterval64("col1", Duration.ZERO)); + + // second column + assertIllegalState("cannot call writeDecimal, actual type: Int32", + () -> row.writeDecimal("col2", DecimalType.getDefault().newValue(0))); + assertIllegalState("cannot call writeUuid, actual type: Int32", + () -> row.writeUuid("col2", UUID.randomUUID())); + } + } + + @Test + public void tinyIntVectorTest() throws IOException { + try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema() + .addColumn("c1", PrimitiveType.Bool) + .addNullableColumn("c2", PrimitiveType.Int8) + .addNullableColumn("c3", PrimitiveType.Uint8) + .createWriter(allocator)) { + + ApacheArrowWriter.Batch batch = writer.createNewBatch(0); + ApacheArrowWriter.Row row = batch.writeNextRow(); + + assertIllegalState("cannot call writeInt8, actual type: Bool", () -> row.writeInt8("c1", (byte) 0)); + assertIllegalState("cannot call writeUint8, actual type: Int8", () -> row.writeUint8("c2", 0)); + assertIllegalState("cannot call writeBool, actual type: Uint8", () -> row.writeBool("c3", false)); + + BulkUpsertArrowData data = batch.buildBatch(); + + Schema schema = readApacheArrowSchema(data.getSchema()); + Assert.assertEquals("Schema", + schema.toString()); + } + } + + @Test + public void 
smallIntVectorTest() throws IOException { + try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema() + .addColumn("c1", PrimitiveType.Date) + .addNullableColumn("c2", PrimitiveType.Int16) + .addNullableColumn("c3", PrimitiveType.Uint16) + .createWriter(allocator)) { + + ApacheArrowWriter.Batch batch = writer.createNewBatch(0); + ApacheArrowWriter.Row row = batch.writeNextRow(); + + assertIllegalState("cannot call writeUint16, actual type: Date", () -> row.writeUint16("c1", 0)); + assertIllegalState("cannot call writeDate, actual type: Int16", () -> row.writeDate("c2", LocalDate.EPOCH)); + assertIllegalState("cannot call writeInt16, actual type: Uint16", () -> row.writeInt16("c3", (short) 0)); + + BulkUpsertArrowData data = batch.buildBatch(); + + Schema schema = readApacheArrowSchema(data.getSchema()); + Assert.assertEquals("Schema", + schema.toString()); + } + } + + @Test + public void intVectorTest() throws IOException { + try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema() + .addColumn("c1", PrimitiveType.Int32) + .addColumn("c2", PrimitiveType.Uint32) + .addNullableColumn("c3", PrimitiveType.Date32) + .addNullableColumn("c4", PrimitiveType.Datetime) + .createWriter(allocator)) { + + ApacheArrowWriter.Batch batch = writer.createNewBatch(0); + ApacheArrowWriter.Row row = batch.writeNextRow(); + + assertIllegalState("cannot call writeUint32, actual type: Int32", () -> row.writeUint32("c1", 0)); + assertIllegalState("cannot call writeDate32, actual type: Uint32", + () -> row.writeDate32("c2", LocalDate.EPOCH)); + assertIllegalState("cannot call writeDatetime, actual type: Date32", + () -> row.writeDatetime("c3", LocalDateTime.now())); + assertIllegalState("cannot call writeInt32, actual type: Datetime", () -> row.writeInt32("c4", 0)); + + BulkUpsertArrowData data = batch.buildBatch(); + + Schema schema = readApacheArrowSchema(data.getSchema()); + Assert.assertEquals("Schema", schema.toString()); + } + } + + @Test + public void bigIntVectorTest() 
throws IOException { + try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema() + .addColumn("c1", PrimitiveType.Int64) + .addColumn("c2", PrimitiveType.Uint64) + .addNullableColumn("c3", PrimitiveType.Datetime64) + .addNullableColumn("c4", PrimitiveType.Timestamp) + .addNullableColumn("c5", PrimitiveType.Timestamp64) + .addNullableColumn("c6", PrimitiveType.Interval) + .addNullableColumn("c7", PrimitiveType.Interval64) + .createWriter(allocator)) { + + ApacheArrowWriter.Batch batch = writer.createNewBatch(0); + ApacheArrowWriter.Row row = batch.writeNextRow(); + + assertIllegalState("cannot call writeUint64, actual type: Int64", () -> row.writeUint64("c1", 0)); + assertIllegalState("cannot call writeDatetime64, actual type: Uint64", + () -> row.writeDatetime64("c2", LocalDateTime.now())); + assertIllegalState("cannot call writeTimestamp, actual type: Datetime64", + () -> row.writeTimestamp("c3", Instant.now())); + assertIllegalState("cannot call writeTimestamp64, actual type: Timestamp", + () -> row.writeTimestamp64("c4", Instant.now())); + assertIllegalState("cannot call writeInterval, actual type: Timestamp64", + () -> row.writeInterval("c5", Duration.ZERO)); + assertIllegalState("cannot call writeInterval64, actual type: Interval", + () -> row.writeInterval64("c6", Duration.ZERO)); + assertIllegalState("cannot call writeInt64, actual type: Interval64", + () -> row.writeInt64("c7", 0)); + + BulkUpsertArrowData data = batch.buildBatch(); + + Schema schema = readApacheArrowSchema(data.getSchema()); + Assert.assertEquals("Schema", + schema.toString()); + } + } + + @Test + public void varCharVectorTest() throws IOException { + try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema() + .addColumn("c1", PrimitiveType.Text) + .addNullableColumn("c2", PrimitiveType.Json) + .addColumn("c3", PrimitiveType.JsonDocument) + .createWriter(allocator)) { + + ApacheArrowWriter.Batch batch = writer.createNewBatch(0); + ApacheArrowWriter.Row row = batch.writeNextRow(); + 
+ assertIllegalState("cannot call writeJson, actual type: Text", () -> row.writeJson("c1", "")); + assertIllegalState("cannot call writeJsonDocument, actual type: Json", () -> row.writeJsonDocument("c2", "")); + assertIllegalState("cannot call writeText, actual type: JsonDocument", () -> row.writeText("c3", "")); + + BulkUpsertArrowData data = batch.buildBatch(); + + Schema schema = readApacheArrowSchema(data.getSchema()); + Assert.assertEquals("Schema", schema.toString()); + } + } + + @Test + public void varBinaryVectorTest() throws IOException { + try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema() + .addColumn("c1", PrimitiveType.Bytes) + .addNullableColumn("c2", PrimitiveType.Yson) + .createWriter(allocator)) { + + ApacheArrowWriter.Batch batch = writer.createNewBatch(0); + ApacheArrowWriter.Row row = batch.writeNextRow(); + + assertIllegalState("cannot call writeYson, actual type: Bytes", () -> row.writeYson("c1", new byte[0])); + assertIllegalState("cannot call writeBytes, actual type: Yson", () -> row.writeBytes("c2", new byte[0])); + + BulkUpsertArrowData data = batch.buildBatch(); + + Schema schema = readApacheArrowSchema(data.getSchema()); + Assert.assertEquals("Schema", schema.toString()); + } + } + + @Test + public void fixedSizeBinaryVectorTest() throws IOException { + DecimalValue dv = DecimalType.getDefault().newValue(0); + UUID uv = UUID.randomUUID(); + + try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema() + .addColumn("c1", PrimitiveType.Uuid) + .addNullableColumn("c2", DecimalType.getDefault()) + .createWriter(allocator)) { + + ApacheArrowWriter.Batch batch = writer.createNewBatch(0); + ApacheArrowWriter.Row row = batch.writeNextRow(); + + assertIllegalState("cannot call writeDecimal, actual type: Uuid", () -> row.writeDecimal("c1", dv)); + assertIllegalState("cannot call writeUuid, actual type: Decimal(22, 9)", () -> row.writeUuid("c2", uv)); + + BulkUpsertArrowData data = batch.buildBatch(); + + Schema schema = 
readApacheArrowSchema(data.getSchema()); + Assert.assertEquals("Schema", schema.toString()); + } + } + + private static Schema readApacheArrowSchema(ByteString bytes) { + try (InputStream is = bytes.newInput()) { + try (ReadChannel channel = new ReadChannel(Channels.newChannel(is))) { + return MessageSerializer.deserializeSchema(channel); + } + } catch (IOException ex) { + throw new RuntimeException(ex); + } + } +} From d9e31c75e2e1a6c3bdf14c0978790e37eb2478fe Mon Sep 17 00:00:00 2001 From: Alexandr Gorshenin Date: Wed, 18 Feb 2026 12:05:49 +0000 Subject: [PATCH 4/8] Added tests for BulkUpsert with validation --- table/pom.xml | 14 + .../ydb/table/integration/AllTypesRecord.java | 437 ++++++++++++++++++ .../ydb/table/integration/BulkUpsertTest.java | 299 ++++++++++++ 3 files changed, 750 insertions(+) create mode 100644 table/src/test/java/tech/ydb/table/integration/AllTypesRecord.java create mode 100644 table/src/test/java/tech/ydb/table/integration/BulkUpsertTest.java diff --git a/table/pom.xml b/table/pom.xml index 6cbb69458..e7775b279 100644 --- a/table/pom.xml +++ b/table/pom.xml @@ -61,9 +61,23 @@ true ydbplatform/local-ydb:trunk + enable_columnshard_bool + + + + jdk8-bootstrap + + [9 + + + + --add-opens=java.base/java.nio=ALL-UNNAMED + + + diff --git a/table/src/test/java/tech/ydb/table/integration/AllTypesRecord.java b/table/src/test/java/tech/ydb/table/integration/AllTypesRecord.java new file mode 100644 index 000000000..67f0d6dc7 --- /dev/null +++ b/table/src/test/java/tech/ydb/table/integration/AllTypesRecord.java @@ -0,0 +1,437 @@ +package tech.ydb.table.integration; + +import java.math.BigInteger; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.ZoneOffset; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; +import 
java.util.UUID; +import java.util.function.BiConsumer; +import java.util.function.Function; +import java.util.stream.Collectors; + +import com.google.common.hash.Hashing; +import org.junit.Assert; + +import tech.ydb.table.description.TableColumn; +import tech.ydb.table.description.TableDescription; +import tech.ydb.table.query.ApacheArrowWriter; +import tech.ydb.table.result.ResultSetReader; +import tech.ydb.table.result.ValueReader; +import tech.ydb.table.values.DecimalType; +import tech.ydb.table.values.DecimalValue; +import tech.ydb.table.values.ListType; +import tech.ydb.table.values.ListValue; +import tech.ydb.table.values.PrimitiveType; +import tech.ydb.table.values.PrimitiveValue; +import tech.ydb.table.values.StructType; +import tech.ydb.table.values.StructValue; +import tech.ydb.table.values.Type; +import tech.ydb.table.values.Value; + +/** + * + * @author Aleksandr Gorshenin + */ +public class AllTypesRecord { + private static final DecimalType YDB_DECIMAL = DecimalType.getDefault(); + private static final DecimalType BANK_DECIMAL = DecimalType.of(31, 9); + private static final DecimalType BIG_DECIMAL = DecimalType.of(35, 0); + + // Keys + private final long id1; + private final int id2; + private final byte[] payload; + private final int length; + private final String hash ; + + // All types + private Byte v_int8; + private Short v_int16; + private Integer v_int32; + private Long v_int64; + + private Integer v_uint8; + private Integer v_uint16; + private Long v_uint32; + private Long v_uint64; + + private Boolean v_bool; + private Float v_float; + private Double v_double; + + private String v_text; + private String v_json; + private String v_jsdoc; + + private byte[] v_bytes; + private byte[] v_yson; + + private UUID v_uuid; + + private LocalDate v_date; + private LocalDateTime v_datetime; + private Instant v_timestamp; + private Duration v_interval; + + private LocalDate v_date32; + private LocalDateTime v_datetime64; + private Instant v_timestamp64; + 
private Duration v_interval64; + + private DecimalValue v_ydb_decimal; + private DecimalValue v_bank_decimal; + private DecimalValue v_big_decimal; + + private AllTypesRecord(long id1, int id2, byte[] payload) { + // Keys + this.id1 = id1; + this.id2 = id2; + this.length = payload.length; + this.payload = new byte[length]; + this.hash = Hashing.sha256().hashBytes(payload).toString(); + } + + private void assertValue(Set columns, String key, int idx, Function reader, + ResultSetReader rs, T value) { + if (!columns.contains(key)) { + return; + } + + String msg = "Row " + idx + " " + key; + Assert.assertEquals(msg, value != null, rs.getColumn(key).isOptionalItemPresent()); + + if (value != null) { + if (value instanceof byte[]) { + Assert.assertArrayEquals(msg, (byte[]) value, (byte[]) reader.apply(rs.getColumn(key))); + } else { + Assert.assertEquals(msg, value, reader.apply(rs.getColumn(key))); + } + } + } + + public void assertRow(Set columns, int idx, ResultSetReader rs) { + // keys are required + Assert.assertEquals("Row " + idx + " id1", id1, rs.getColumn("id1").getUint64()); + Assert.assertEquals("Row " + idx + " id2", id2, rs.getColumn("id2").getInt64()); + Assert.assertArrayEquals("Row " + idx + " payload", payload, rs.getColumn("payload").getBytes()); + Assert.assertEquals("Row " + idx + " length", length, rs.getColumn("length").getUint32()); + Assert.assertEquals("Row " + idx + " hash", hash, rs.getColumn("hash").getText()); + + // other columns may be skipped or be empty + assertValue(columns, "Int8", idx, ValueReader::getInt8, rs, v_int8); + assertValue(columns, "Int16", idx, ValueReader::getInt16, rs, v_int16); + assertValue(columns, "Int32", idx, ValueReader::getInt32, rs, v_int32); + assertValue(columns, "Int64", idx, ValueReader::getInt64, rs, v_int64); + + assertValue(columns, "Uint8", idx, ValueReader::getUint8, rs, v_uint8); + assertValue(columns, "Uint16", idx, ValueReader::getUint16, rs, v_uint16); + assertValue(columns, "Uint32", idx, 
ValueReader::getUint32, rs, v_uint32); + assertValue(columns, "Uint64", idx, ValueReader::getUint64, rs, v_uint64); + + assertValue(columns, "Bool", idx, ValueReader::getBool, rs, v_bool); + assertValue(columns, "Float", idx, ValueReader::getFloat, rs, v_float); + assertValue(columns, "Double", idx, ValueReader::getDouble, rs, v_double); + + assertValue(columns, "Text", idx, ValueReader::getText, rs, v_text); + assertValue(columns, "Json", idx, ValueReader::getJson, rs, v_json); + assertValue(columns, "JsonDocument", idx, ValueReader::getJsonDocument, rs, v_jsdoc); + + assertValue(columns, "Uuid", idx, ValueReader::getUuid, rs, v_uuid); + + assertValue(columns, "Bytes", idx, ValueReader::getBytes, rs, v_bytes); + assertValue(columns, "Yson", idx, ValueReader::getYson, rs, v_yson); + + assertValue(columns, "Date", idx, ValueReader::getDate, rs, v_date); + assertValue(columns, "Datetime", idx, ValueReader::getDatetime, rs, v_datetime); + assertValue(columns, "Timestamp", idx, ValueReader::getTimestamp, rs, v_timestamp); + assertValue(columns, "Interval", idx, ValueReader::getInterval, rs, v_interval); + + assertValue(columns, "Date32", idx, ValueReader::getDate32, rs, v_date32); + assertValue(columns, "Datetime64", idx, ValueReader::getDatetime64, rs, v_datetime64); + assertValue(columns, "Timestamp64", idx, ValueReader::getTimestamp64, rs, v_timestamp64); + assertValue(columns, "Interval64", idx, ValueReader::getInterval64, rs, v_interval64); + + assertValue(columns, "YdbDecimal", idx, ValueReader::getDecimal, rs, v_ydb_decimal); + assertValue(columns, "BankDecimal", idx, ValueReader::getDecimal, rs, v_bank_decimal); + assertValue(columns, "BigDecimal", idx, ValueReader::getDecimal, rs, v_big_decimal); + } + + private Value makePb(Type type, Function> func, T value) { + return value != null ? 
func.apply(value) : type.makeOptional().emptyValue(); + } + + private Map> toPb(List columnOnly) { + Map> struct = new HashMap<>(); + + BiConsumer> write = (key, value) -> { + if (columnOnly.contains(key)) { + struct.put(key, value); + } + }; + + // keys are required + struct.put("id1", PrimitiveValue.newUint64(id1)); + struct.put("id2", PrimitiveValue.newInt64(id2)); + struct.put("payload", PrimitiveValue.newBytesOwn(payload)); + struct.put("length", PrimitiveValue.newUint32(length)); + struct.put("hash", PrimitiveValue.newText(hash)); + + // other columns may be skipped or be empty + write.accept("Int8", makePb(PrimitiveType.Int8, PrimitiveValue::newInt8, v_int8)); + write.accept("Int16", makePb(PrimitiveType.Int16, PrimitiveValue::newInt16, v_int16)); + write.accept("Int32", makePb(PrimitiveType.Int32, PrimitiveValue::newInt32, v_int32)); + write.accept("Int64", makePb(PrimitiveType.Int64, PrimitiveValue::newInt64, v_int64)); + + write.accept("Uint8", makePb(PrimitiveType.Uint8, PrimitiveValue::newUint8, v_uint8)); + write.accept("Uint16", makePb(PrimitiveType.Uint16, PrimitiveValue::newUint16, v_uint16)); + write.accept("Uint32", makePb(PrimitiveType.Uint32, PrimitiveValue::newUint32, v_uint32)); + write.accept("Uint64", makePb(PrimitiveType.Uint64, PrimitiveValue::newUint64, v_uint64)); + + write.accept("Bool", makePb(PrimitiveType.Bool, PrimitiveValue::newBool, v_bool)); + write.accept("Float", makePb(PrimitiveType.Float, PrimitiveValue::newFloat, v_float)); + write.accept("Double", makePb(PrimitiveType.Double, PrimitiveValue::newDouble, v_double)); + + write.accept("Text", makePb(PrimitiveType.Text, PrimitiveValue::newText, v_text)); + write.accept("Json", makePb(PrimitiveType.Json, PrimitiveValue::newJson, v_json)); + write.accept("JsonDocument", makePb(PrimitiveType.JsonDocument, PrimitiveValue::newJsonDocument, v_jsdoc)); + + write.accept("Bytes", makePb(PrimitiveType.Bytes, PrimitiveValue::newBytes, v_bytes)); + write.accept("Yson", 
makePb(PrimitiveType.Yson, PrimitiveValue::newYson, v_yson)); + + write.accept("Uuid", makePb(PrimitiveType.Uuid, PrimitiveValue::newUuid, v_uuid)); + + write.accept("Date", makePb(PrimitiveType.Date, PrimitiveValue::newDate, v_date)); + write.accept("Datetime", makePb(PrimitiveType.Datetime, PrimitiveValue::newDatetime, v_datetime)); + write.accept("Timestamp", makePb(PrimitiveType.Timestamp, PrimitiveValue::newTimestamp, v_timestamp)); + write.accept("Interval", makePb(PrimitiveType.Interval, PrimitiveValue::newInterval, v_interval)); + + write.accept("Date32", makePb(PrimitiveType.Date32, PrimitiveValue::newDate32, v_date32)); + write.accept("Datetime64", makePb(PrimitiveType.Datetime64, PrimitiveValue::newDatetime64, v_datetime64)); + write.accept("Timestamp64", makePb(PrimitiveType.Timestamp64, PrimitiveValue::newTimestamp64, v_timestamp64)); + write.accept("Interval64", makePb(PrimitiveType.Interval64, PrimitiveValue::newInterval64, v_interval64)); + + write.accept("YdbDecimal", makePb(YDB_DECIMAL, Function.identity(), v_ydb_decimal)); + write.accept("BankDecimal", makePb(BANK_DECIMAL, Function.identity(), v_bank_decimal)); + write.accept("BigDecimal", makePb(BIG_DECIMAL, Function.identity(), v_big_decimal)); + + return struct; + } + + private void writeNullable(Set columns, String key, ApacheArrowWriter.Row row, + BiConsumer writer, T value) { + if (!columns.contains(key)) { + return; + } + + if (value == null) { + row.writeNull(key); + } else { + writer.accept(key, value); + } + } + + public void writeToApacheArrow(Set columnNames, ApacheArrowWriter.Row row) { + // keys are required + row.writeUint64("id1", id1); + row.writeInt64("id2", id2); + row.writeBytes("payload", payload); + row.writeUint32("length", length); + row.writeText("hash", hash); + + // other columns may be skipped or be empty + writeNullable(columnNames, "Int8", row, row::writeInt8, v_int8); + writeNullable(columnNames, "Int16", row, row::writeInt16, v_int16); + writeNullable(columnNames, 
"Int32", row, row::writeInt32, v_int32); + writeNullable(columnNames, "Int64", row, row::writeInt64, v_int64); + + writeNullable(columnNames, "Uint8", row, row::writeUint8, v_uint8); + writeNullable(columnNames, "Uint16", row, row::writeUint16, v_uint16); + writeNullable(columnNames, "Uint32", row, row::writeUint32, v_uint32); + writeNullable(columnNames, "Uint64", row, row::writeUint64, v_uint64); + + writeNullable(columnNames, "Bool", row, row::writeBool, v_bool); + + writeNullable(columnNames, "Float", row, row::writeFloat, v_float); + writeNullable(columnNames, "Double", row, row::writeDouble, v_double); + + writeNullable(columnNames, "Text", row, row::writeText, v_text); + writeNullable(columnNames, "Json", row, row::writeJson, v_json); + writeNullable(columnNames, "JsonDocument", row, row::writeJsonDocument, v_jsdoc); + + writeNullable(columnNames, "Bytes", row, row::writeBytes, v_bytes); + writeNullable(columnNames, "Yson", row, row::writeYson, v_yson); + + writeNullable(columnNames, "Uuid", row, row::writeUuid, v_uuid); + + writeNullable(columnNames, "Date", row, row::writeDate, v_date); + writeNullable(columnNames, "Datetime", row, row::writeDatetime, v_datetime); + writeNullable(columnNames, "Timestamp", row, row::writeTimestamp, v_timestamp); + writeNullable(columnNames, "Interval", row, row::writeInterval, v_interval); + + writeNullable(columnNames, "Date32", row, row::writeDate32, v_date32); + writeNullable(columnNames, "Datetime64", row, row::writeDatetime64, v_datetime64); + writeNullable(columnNames, "Timestamp64", row, row::writeTimestamp64, v_timestamp64); + writeNullable(columnNames, "Interval64", row, row::writeInterval64, v_interval64); + + writeNullable(columnNames, "YdbDecimal", row, row::writeDecimal, v_ydb_decimal); + writeNullable(columnNames, "BankDecimal", row, row::writeDecimal, v_bank_decimal); + writeNullable(columnNames, "BigDecimal", row, row::writeDecimal, v_big_decimal); + } + + public static DecimalValue randomDecimal(DecimalType 
type, Random rnd) { + int kind = rnd.nextInt(1000); + switch (kind) { + case 0: return type.getNegInf(); + case 499: return type.newValue(0); + case 999: return type.getInf(); + default: + break; + } + BigInteger unscaled = new BigInteger(1 + rnd.nextInt(type.getPrecision() * 3 - 1), rnd); + if (kind < 500) { + unscaled = unscaled.negate(); + } + return type.newValueUnscaled(unscaled); + } + + private static T nullable(Random rnd, T value) { + return rnd.nextInt(10) == 0 ? null : value; + } + + public static AllTypesRecord random(long id1, int id2, Random rnd) { + int length = 10 + rnd.nextInt(256); + byte[] payload = new byte[length]; + rnd.nextBytes(payload); + + AllTypesRecord r = new AllTypesRecord(id1, id2, payload); + + // All types + r.v_int8 = nullable(rnd, (byte) rnd.nextInt()); + r.v_int16 = nullable(rnd, (short) rnd.nextInt()); + r.v_int32 = nullable(rnd, rnd.nextInt()); + r.v_int64 = nullable(rnd, rnd.nextLong()); + + r.v_uint8 = nullable(rnd, rnd.nextInt() & 0xFF); + r.v_uint16 = nullable(rnd, rnd.nextInt() & 0xFFFF); + r.v_uint32 = nullable(rnd, rnd.nextLong() & 0xFFFFFFFFL); + r.v_uint64 = nullable(rnd, rnd.nextLong()); + + r.v_bool = nullable(rnd, rnd.nextBoolean()); + r.v_float = nullable(rnd, rnd.nextFloat()); + r.v_double = nullable(rnd, rnd.nextDouble()); + + r.v_text = nullable(rnd, "Text" + rnd.nextInt(1000)); + r.v_json = nullable(rnd, "{\"json\":" + rnd.nextInt(1000, 2000) + "}"); + r.v_jsdoc = nullable(rnd, "{\"document\":" + rnd.nextInt(2000, 3000) + "}"); + + r.v_bytes = nullable(rnd, ("Bytes " + rnd.nextInt(3000, 4000)).getBytes(StandardCharsets.UTF_8)); + r.v_yson = nullable(rnd, ("{yson=" + rnd.nextInt(5000, 6000) + "}").getBytes(StandardCharsets.UTF_8)); + + r.v_uuid = nullable(rnd, UUID.nameUUIDFromBytes(("UUID" + rnd.nextInt()).getBytes(StandardCharsets.UTF_8))); + + r.v_date = nullable(rnd, LocalDate.EPOCH.plusDays(rnd.nextInt(5000))); + r.v_datetime = nullable(rnd, LocalDateTime.ofEpochSecond(rnd.nextLong(1000000000), 0, 
ZoneOffset.ofHours(5))); + r.v_timestamp = nullable(rnd, Instant.ofEpochSecond(rnd.nextLong(2000000000L), rnd.nextLong(1000000) * 1000)); + r.v_interval = nullable(rnd, Duration.ofNanos((rnd.nextLong(10000000) - 5000000) * 1000)); + + r.v_date32 = nullable(rnd, LocalDate.EPOCH.plusDays(rnd.nextInt(5000) - 2500)); + r.v_datetime64 = nullable(rnd, LocalDateTime.ofEpochSecond(rnd.nextLong(1000000000) - 500000000, 0, + ZoneOffset.ofHours(5))); + r.v_timestamp64 = nullable(rnd, Instant.ofEpochSecond(rnd.nextLong(10000000000L) - 5000000000L, + rnd.nextLong(1000000) * 1000)); + r.v_interval64 = nullable(rnd, Duration.ofNanos((rnd.nextLong(10000000) - 5000000) * 1000)); + + r.v_ydb_decimal = nullable(rnd, randomDecimal(YDB_DECIMAL, rnd)); + r.v_bank_decimal = nullable(rnd, randomDecimal(BANK_DECIMAL, rnd)); + r.v_big_decimal = nullable(rnd, randomDecimal(BIG_DECIMAL, rnd)); + + return r; + } + + public static List randomBatch(int id1, int id2_start, int count) { + Random rnd = new Random(id1 * count + id2_start); + List batch = new ArrayList<>(count); + for (int idx = 0; idx < count; idx += 1) { + batch.add(random(id1, id2_start + idx, rnd)); + } + return batch; + } + + public static ListValue createProtobufBatch(TableDescription desc, List batch) { + List columnNames = desc.getColumns().stream().map(TableColumn::getName).collect(Collectors.toList()); + List columnTypes = desc.getColumns().stream().map(TableColumn::getType).collect(Collectors.toList()); + + StructType type = StructType.of(columnNames, columnTypes); + List values = batch.stream().map(r -> type.newValue(r.toPb(columnNames))) + .collect(Collectors.toList()); + + return ListType.of(type).newValue(values); + } + + public static TableDescription createTableDescription(boolean isColumnShard, boolean isApacheArrow) { + TableDescription.Builder builder = TableDescription.newBuilder() + .addNonnullColumn("id1", PrimitiveType.Uint64) + .addNonnullColumn("id2", PrimitiveType.Int64) + .addNonnullColumn("payload", 
PrimitiveType.Bytes) + .addNonnullColumn("length", PrimitiveType.Uint32) + .addNonnullColumn("hash", PrimitiveType.Text) + + .addNullableColumn("Int8", PrimitiveType.Int8) + .addNullableColumn("Int16", PrimitiveType.Int16) + .addNullableColumn("Int32", PrimitiveType.Int32) + .addNullableColumn("Int64", PrimitiveType.Int64) + .addNullableColumn("Uint8", PrimitiveType.Uint8) + .addNullableColumn("Uint16", PrimitiveType.Uint16) + .addNullableColumn("Uint32", PrimitiveType.Uint32) + .addNullableColumn("Uint64", PrimitiveType.Uint64) + .addNullableColumn("Bool", PrimitiveType.Bool) + .addNullableColumn("Float", PrimitiveType.Float) + .addNullableColumn("Double", PrimitiveType.Double) + .addNullableColumn("Text", PrimitiveType.Text) + .addNullableColumn("Json", PrimitiveType.Json) + .addNullableColumn("JsonDocument", PrimitiveType.JsonDocument) + .addNullableColumn("Bytes", PrimitiveType.Bytes) + .addNullableColumn("Yson", PrimitiveType.Yson); + + // https://github.com/ydb-platform/ydb/issues/13047 + if (!isColumnShard) { + builder = builder.addNullableColumn("Uuid", PrimitiveType.Uuid); + } + + builder = builder + .addNullableColumn("Date", PrimitiveType.Date) + .addNullableColumn("Datetime", PrimitiveType.Datetime) + .addNullableColumn("Timestamp", PrimitiveType.Timestamp); + + // https://github.com/ydb-platform/ydb/issues/13050 + if (!isColumnShard) { + builder = builder.addNullableColumn("Interval", PrimitiveType.Interval); + } + + builder = builder + .addNullableColumn("Date32", PrimitiveType.Date32) + .addNullableColumn("Datetime64", PrimitiveType.Datetime64) + .addNullableColumn("Timestamp64", PrimitiveType.Timestamp64) + .addNullableColumn("Interval64", PrimitiveType.Interval64) + .addNullableColumn("YdbDecimal", YDB_DECIMAL) + .addNullableColumn("BankDecimal", BANK_DECIMAL) + .addNullableColumn("BigDecimal", BIG_DECIMAL); + + if (isColumnShard) { + builder = builder.setPrimaryKey("hash").setStoreType(TableDescription.StoreType.COLUMN); + } else { + builder = 
builder.setPrimaryKeys("id1", "id2").setStoreType(TableDescription.StoreType.ROW); + } + + return builder.build(); + } +} diff --git a/table/src/test/java/tech/ydb/table/integration/BulkUpsertTest.java b/table/src/test/java/tech/ydb/table/integration/BulkUpsertTest.java new file mode 100644 index 000000000..769a68171 --- /dev/null +++ b/table/src/test/java/tech/ydb/table/integration/BulkUpsertTest.java @@ -0,0 +1,299 @@ +package tech.ydb.table.integration; + +import java.io.IOException; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CompletionException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiConsumer; +import java.util.stream.Collectors; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.junit.Assert; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; + +import tech.ydb.core.grpc.GrpcReadStream; +import tech.ydb.table.SessionRetryContext; +import tech.ydb.table.description.TableColumn; +import tech.ydb.table.description.TableDescription; +import tech.ydb.table.impl.SimpleTableClient; +import tech.ydb.table.query.ApacheArrowWriter; +import tech.ydb.table.query.BulkUpsertData; +import tech.ydb.table.query.Params; +import tech.ydb.table.result.ResultSetReader; +import tech.ydb.table.rpc.grpc.GrpcTableRpc; +import tech.ydb.table.settings.ExecuteScanQuerySettings; +import tech.ydb.table.settings.ExecuteSchemeQuerySettings; +import tech.ydb.table.values.ListValue; +import tech.ydb.table.values.PrimitiveValue; +import tech.ydb.test.junit4.GrpcTransportRule; + +/** + * + * @author Aleksandr Gorshenin + */ +public class BulkUpsertTest { + @ClassRule + public static final GrpcTransportRule YDB = new GrpcTransportRule(); + private static final SimpleTableClient client = SimpleTableClient.newClient(GrpcTableRpc.useTransport(YDB)).build(); + private static final SessionRetryContext retryCtx = 
SessionRetryContext.create(client) + .idempotent(false) + .build(); + + private static final String TEST_TABLE = "arrow/test_table"; + private static final String DROP_TABLE_YQL = "DROP TABLE IF EXISTS `" + TEST_TABLE + "`;"; + private static final String SELECT_TABLE_YQL = "DECLARE $id1 AS Uint64; SELECT * FROM `" + TEST_TABLE + "` " + + "WHERE id1 = $id1 ORDER BY id2"; + + private static String tablePath() { + return YDB.getDatabase() + "/" + TEST_TABLE; + } + +// @After + @Before + public void cleanTable() { + retryCtx.supplyStatus(session -> session.executeSchemeQuery(DROP_TABLE_YQL, new ExecuteSchemeQuerySettings())) + .join().expectSuccess("cannot drop table"); + } + + private static void createTable(TableDescription table) { + retryCtx.supplyStatus(s -> s.createTable(tablePath(), table)) + .join().expectSuccess("Cannot create table"); + } + + private static void bulkUpsert(BulkUpsertData data) { + retryCtx.supplyStatus(session -> session.executeBulkUpsert(tablePath(), data)) + .join().expectSuccess("bulk upsert problem in table " + tablePath()); + } + + private static void bulkUpsert(ListValue rows) { + retryCtx.supplyStatus(session -> session.executeBulkUpsert(tablePath(), rows)) + .join().expectSuccess("bulk upsert problem in table " + tablePath()); + } + + private static int readTable(long id1, BiConsumer validator) { + AtomicInteger count = new AtomicInteger(); + + try { + retryCtx.supplyStatus(session -> { + count.set(0); + + GrpcReadStream stream = session.executeScanQuery(SELECT_TABLE_YQL, + Params.of("$id1", PrimitiveValue.newUint64(id1)), + ExecuteScanQuerySettings.newBuilder().build() + ); + + return stream.start((rs) -> { + while (rs.next()) { + int idx = count.getAndIncrement(); + validator.accept(idx, rs); + } + }); + }).join().expectSuccess("Cannot read table " + TEST_TABLE); + } catch (CompletionException ex) { + if (ex.getCause() instanceof AssertionError) { + throw (AssertionError) ex.getCause(); + } + throw ex; + } + + return count.get(); + 
} + + @Test + public void writeProtobufToDataShardTest() { + // Create table + TableDescription table = AllTypesRecord.createTableDescription(false, false); + createTable(table); + + Set columnNames = table.getColumns().stream().map(TableColumn::getName).collect(Collectors.toSet()); + + // Write & read batch of 1000 records with id1 = 1 + List batch1 = AllTypesRecord.randomBatch(1, 1, 1000); + bulkUpsert(AllTypesRecord.createProtobufBatch(table, batch1)); + + int rows1Count = readTable(1, (idx, rs) -> { + Assert.assertTrue("Unexpected row index", idx < batch1.size()); + batch1.get(idx).assertRow(columnNames, idx, rs); + }); + Assert.assertEquals(1000, rows1Count); + + // Write & read batch of 2000 records with id1 = 2 + List batch2 = AllTypesRecord.randomBatch(2, 1, 2000); + bulkUpsert(AllTypesRecord.createProtobufBatch(table, batch2)); + + int rows2Count = readTable(2, (idx, rs) -> { + Assert.assertTrue("Unexpected row index", idx < batch2.size()); + batch2.get(idx).assertRow(columnNames, idx, rs); + }); + Assert.assertEquals(2000, rows2Count); + } + + @Test + public void writeProtobufToColumnShardTable() { + // Create table + TableDescription table = AllTypesRecord.createTableDescription(true, false); + Set columnNames = table.getColumns().stream().map(TableColumn::getName).collect(Collectors.toSet()); + + createTable(table); + + // Write & read batch of 5000 records with id1 = 1 + List batch1 = AllTypesRecord.randomBatch(1, 1, 5000); + bulkUpsert(AllTypesRecord.createProtobufBatch(table, batch1)); + // Read table and validate data + int rows1Count = readTable(1, (idx, rs) -> { + Assert.assertTrue("Unexpected row index", idx < batch1.size()); + batch1.get(idx).assertRow(columnNames, idx, rs); + }); + Assert.assertEquals(5000, rows1Count); + + // Write & read batch of 10000 records with id1 = 2 + List batch2 = AllTypesRecord.randomBatch(2, 1, 10000); + bulkUpsert(AllTypesRecord.createProtobufBatch(table, batch2)); + + int rows2Count = readTable(2, (idx, rs) -> { + 
Assert.assertTrue("Unexpected row index", idx < batch2.size()); + batch2.get(idx).assertRow(columnNames, idx, rs); + }); + Assert.assertEquals(10000, rows2Count); + } + + @Test + public void writeApacheArrowToDataShardTest() { + // Create table + TableDescription table = AllTypesRecord.createTableDescription(false, true); + retryCtx.supplyStatus(s -> s.createTable(tablePath(), table)).join().expectSuccess("Cannot create table"); + + Set columnNames = table.getColumns().stream().map(TableColumn::getName).collect(Collectors.toSet()); + + ApacheArrowWriter.Schema schema = ApacheArrowWriter.newSchema(); + table.getColumns().forEach(column -> schema.addColumn(column.getName(), column.getType())); + + // Create batch of 1000 records + List batch1 = AllTypesRecord.randomBatch(1, 1, 1000); + // Create batch of 2000 records + List batch2 = AllTypesRecord.randomBatch(2, 1, 2000); + + try (BufferAllocator allocator = new RootAllocator()) { + try (ApacheArrowWriter writer = schema.createWriter(allocator)) { + // create batch with estimated size + ApacheArrowWriter.Batch data1 = writer.createNewBatch(1000); + batch1.forEach(r -> r.writeToApacheArrow(columnNames, data1.writeNextRow())); + bulkUpsert(data1.buildBatch()); + + // create batch without estimated size + ApacheArrowWriter.Batch data2 = writer.createNewBatch(0); + batch2.forEach(r -> r.writeToApacheArrow(columnNames, data2.writeNextRow())); + bulkUpsert(data2.buildBatch()); + + } catch (IOException ex) { + throw new AssertionError("Cannot serialize apache arrow", ex); + } + } + + int rows1Count = readTable(1, (idx, rs) -> { + Assert.assertTrue("Unexpected row index", idx < batch1.size()); + batch1.get(idx).assertRow(columnNames, idx, rs); + }); + Assert.assertEquals(1000, rows1Count); + + int rows2Count = readTable(2, (idx, rs) -> { + Assert.assertTrue("Unexpected row index ", idx < batch2.size()); + batch2.get(idx).assertRow(columnNames, idx, rs); + }); + Assert.assertEquals(2000, rows2Count); + } + + @Test + public 
void writeApacheArrowToColumnShardTest() { + // Create table + TableDescription table = AllTypesRecord.createTableDescription(true, true); + retryCtx.supplyStatus(s -> s.createTable(tablePath(), table)).join().expectSuccess("Cannot create table"); + + Set columnNames = table.getColumns().stream().map(TableColumn::getName).collect(Collectors.toSet()); + + ApacheArrowWriter.Schema schema = ApacheArrowWriter.newSchema(); + table.getColumns().forEach(column -> schema.addColumn(column.getName(), column.getType())); + + // Create batch of 5000 records + List batch1 = AllTypesRecord.randomBatch(1, 1, 5000); + // Create batch of 10000 records + List batch2 = AllTypesRecord.randomBatch(2, 1, 10000); + + try (BufferAllocator allocator = new RootAllocator()) { + try (ApacheArrowWriter writer = schema.createWriter(allocator)) { + // create batch without estimated size + ApacheArrowWriter.Batch data1 = writer.createNewBatch(0); + batch1.forEach(r -> r.writeToApacheArrow(columnNames, data1.writeNextRow())); + bulkUpsert(data1.buildBatch()); + + // create batch with estimated size + ApacheArrowWriter.Batch data2 = writer.createNewBatch(10000); + batch2.forEach(r -> r.writeToApacheArrow(columnNames, data2.writeNextRow())); + bulkUpsert(data2.buildBatch()); + + } catch (IOException ex) { + throw new AssertionError("Cannot serialize apache arrow", ex); + } + } + + int rows1Count = readTable(1, (idx, rs) -> { + Assert.assertTrue("Unexpected row index", idx < batch1.size()); + batch1.get(idx).assertRow(columnNames, idx, rs); + }); + Assert.assertEquals(5000, rows1Count); + + int rows2Count = readTable(2, (idx, rs) -> { + Assert.assertTrue("Unexpected row index ", idx < batch2.size()); + batch2.get(idx).assertRow(columnNames, idx, rs); + }); + Assert.assertEquals(10000, rows2Count); + } + +// +// private void assertApacheArrowBatch(ByteString schemaBytes, ByteString batchBytes, Iterator it) { +// try (BufferAllocator allocator = new RootAllocator()) { +// Schema schema = 
readApacheArrowSchema(schemaBytes); +// try (VectorSchemaRoot vector = VectorSchemaRoot.create(schema, allocator)) { +// try (InputStream is = batchBytes.newInput()) { +// try (ReadChannel channel = new ReadChannel(Channels.newChannel(is))) { +// try (ArrowRecordBatch batch = MessageSerializer.deserializeRecordBatch(channel, allocator)) { +// VectorLoader loader = new VectorLoader(vector); +// loader.load(batch); +// } +// } +// } +// +// UInt8Vector id1 = (UInt8Vector) vector.getVector("id1"); +// BigIntVector id2 = (BigIntVector) vector.getVector("id2"); +// IntVector length = (IntVector) vector.getVector("length"); +// VarCharVector hash = (VarCharVector) vector.getVector("hash"); +// VarBinaryVector data = (VarBinaryVector) vector.getVector("data"); +// TimeStampMicroTZVector tm = (TimeStampMicroTZVector) vector.getVector("timestamp"); +// UInt2Vector date = (UInt2Vector) vector.getVector("date"); +// Float8Vector amount = (Float8Vector) vector.getVector("amount"); +// +// for (int idx = 0; idx < vector.getRowCount(); idx++) { +// Assert.assertTrue("Assert has no row " + idx, it.hasNext()); +// Record r = it.next(); +// +// Assert.assertEquals("Row " + idx + " fail", r.id1, id1.get(idx)); +// Assert.assertEquals("Row " + idx + " fail", r.id2, id2.get(idx)); +// Assert.assertEquals("Row " + idx + " fail", r.length, length.get(idx)); +// Assert.assertArrayEquals("Row " + idx + " fail", r.hash.getBytes(), hash.get(idx)); +// Assert.assertArrayEquals("Row " + idx + " fail", r.data, data.get(idx)); +// +// long rm = r.timestamp.getEpochSecond() * 1000000L + r.timestamp.getNano() / 1000; +// Assert.assertEquals("Row " + idx + " fail", rm, tm.get(idx)); +// Assert.assertEquals("Row " + idx + " fail", r.date, LocalDate.ofEpochDay(date.get(idx))); +// Assert.assertEquals("Row " + idx + " fail", r.amount, amount.get(idx), 1e-6); +// } +// } +// } catch (IOException ex) { +// throw new RuntimeException(ex); +// } +// } +} From 96c01e8d798dc014d3e6c7a8dae88d38d857c3d9 Mon 
Sep 17 00:00:00 2001 From: Alexandr Gorshenin Date: Wed, 18 Feb 2026 12:33:45 +0000 Subject: [PATCH 5/8] Fixed JDK8 compability --- pom.xml | 4 ++- .../ydb/table/query/ApacheArrowWriter.java | 4 +-- .../ydb/table/integration/AllTypesRecord.java | 31 ++++++++++--------- .../ydb/table/integration/BulkUpsertTest.java | 13 ++++---- .../table/query/ApacheArrowWriterTest.java | 13 ++++---- 5 files changed, 34 insertions(+), 31 deletions(-) diff --git a/pom.xml b/pom.xml index c7589ba9b..1f897e3cb 100644 --- a/pom.xml +++ b/pom.xml @@ -41,7 +41,9 @@ 2.8.9 5.11.0 1.19.3 - 18.3.0 + + + 17.0.0 diff --git a/table/src/main/java/tech/ydb/table/query/ApacheArrowWriter.java b/table/src/main/java/tech/ydb/table/query/ApacheArrowWriter.java index 072541917..bd7300254 100644 --- a/table/src/main/java/tech/ydb/table/query/ApacheArrowWriter.java +++ b/table/src/main/java/tech/ydb/table/query/ApacheArrowWriter.java @@ -622,7 +622,7 @@ void writeInterval(int rowIndex, Duration duration) { if (type != PrimitiveType.Interval) { throw error("writeInterval"); } - long micros = duration.getSeconds() * 1000000L + duration.toNanosPart() / 1000; + long micros = duration.getSeconds() * 1000000L + duration.getNano() / 1000; vector.setSafe(rowIndex, micros); } @@ -631,7 +631,7 @@ void writeInterval64(int rowIndex, Duration duration) { if (type != PrimitiveType.Interval64) { throw error("writeInterval64"); } - long micros = duration.getSeconds() * 1000000L + duration.toNanosPart() / 1000; + long micros = duration.getSeconds() * 1000000L + duration.getNano() / 1000; vector.setSafe(rowIndex, micros); } } diff --git a/table/src/test/java/tech/ydb/table/integration/AllTypesRecord.java b/table/src/test/java/tech/ydb/table/integration/AllTypesRecord.java index 67f0d6dc7..9752d1044 100644 --- a/table/src/test/java/tech/ydb/table/integration/AllTypesRecord.java +++ b/table/src/test/java/tech/ydb/table/integration/AllTypesRecord.java @@ -330,25 +330,26 @@ public static AllTypesRecord random(long id1, int 
id2, Random rnd) { r.v_double = nullable(rnd, rnd.nextDouble()); r.v_text = nullable(rnd, "Text" + rnd.nextInt(1000)); - r.v_json = nullable(rnd, "{\"json\":" + rnd.nextInt(1000, 2000) + "}"); - r.v_jsdoc = nullable(rnd, "{\"document\":" + rnd.nextInt(2000, 3000) + "}"); + r.v_json = nullable(rnd, "{\"json\":" + (1000 + rnd.nextInt(1000)) + "}"); + r.v_jsdoc = nullable(rnd, "{\"document\":" + (2000 + rnd.nextInt(1000)) + "}"); - r.v_bytes = nullable(rnd, ("Bytes " + rnd.nextInt(3000, 4000)).getBytes(StandardCharsets.UTF_8)); - r.v_yson = nullable(rnd, ("{yson=" + rnd.nextInt(5000, 6000) + "}").getBytes(StandardCharsets.UTF_8)); + r.v_bytes = nullable(rnd, ("Bytes " + (3000 + rnd.nextInt(1000))).getBytes(StandardCharsets.UTF_8)); + r.v_yson = nullable(rnd, ("{yson=" + (4000 + rnd.nextInt(1000)) + "}").getBytes(StandardCharsets.UTF_8)); r.v_uuid = nullable(rnd, UUID.nameUUIDFromBytes(("UUID" + rnd.nextInt()).getBytes(StandardCharsets.UTF_8))); - r.v_date = nullable(rnd, LocalDate.EPOCH.plusDays(rnd.nextInt(5000))); - r.v_datetime = nullable(rnd, LocalDateTime.ofEpochSecond(rnd.nextLong(1000000000), 0, ZoneOffset.ofHours(5))); - r.v_timestamp = nullable(rnd, Instant.ofEpochSecond(rnd.nextLong(2000000000L), rnd.nextLong(1000000) * 1000)); - r.v_interval = nullable(rnd, Duration.ofNanos((rnd.nextLong(10000000) - 5000000) * 1000)); - - r.v_date32 = nullable(rnd, LocalDate.EPOCH.plusDays(rnd.nextInt(5000) - 2500)); - r.v_datetime64 = nullable(rnd, LocalDateTime.ofEpochSecond(rnd.nextLong(1000000000) - 500000000, 0, - ZoneOffset.ofHours(5))); - r.v_timestamp64 = nullable(rnd, Instant.ofEpochSecond(rnd.nextLong(10000000000L) - 5000000000L, - rnd.nextLong(1000000) * 1000)); - r.v_interval64 = nullable(rnd, Duration.ofNanos((rnd.nextLong(10000000) - 5000000) * 1000)); + r.v_date = nullable(rnd, LocalDate.ofEpochDay(rnd.nextInt(5000))); + r.v_datetime = nullable(rnd, LocalDateTime.ofEpochSecond(0x60FFFFFF & rnd.nextLong(), 0, ZoneOffset.UTC)); + r.v_timestamp = nullable(rnd, 
Instant.ofEpochSecond(0x3FFFFFFFL & rnd.nextLong(), + rnd.nextInt(1000000) * 1000)); + r.v_interval = nullable(rnd, Duration.ofNanos((0x7FFFFFFFFFL & rnd.nextLong() - 0x4000000000L) * 1000)); + + r.v_date32 = nullable(rnd, LocalDate.ofEpochDay(rnd.nextInt(5000) - 2500)); + r.v_datetime64 = nullable(rnd, LocalDateTime.ofEpochSecond(0x60FFFFFF & rnd.nextLong() - 0x30FFFFFF, 0, + ZoneOffset.UTC)); + r.v_timestamp64 = nullable(rnd, Instant.ofEpochSecond(0x7FFFFFFFFFL & rnd.nextLong() - 0x4000000000L, + rnd.nextInt(1000000) * 1000)); + r.v_interval64 = nullable(rnd, Duration.ofNanos((0x7FFFFFFFFFL & rnd.nextLong() - 0x4000000000L) * 1000)); r.v_ydb_decimal = nullable(rnd, randomDecimal(YDB_DECIMAL, rnd)); r.v_bank_decimal = nullable(rnd, randomDecimal(BANK_DECIMAL, rnd)); diff --git a/table/src/test/java/tech/ydb/table/integration/BulkUpsertTest.java b/table/src/test/java/tech/ydb/table/integration/BulkUpsertTest.java index 769a68171..83b4fe59a 100644 --- a/table/src/test/java/tech/ydb/table/integration/BulkUpsertTest.java +++ b/table/src/test/java/tech/ydb/table/integration/BulkUpsertTest.java @@ -10,8 +10,8 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; +import org.junit.After; import org.junit.Assert; -import org.junit.Before; import org.junit.ClassRule; import org.junit.Test; @@ -52,8 +52,7 @@ private static String tablePath() { return YDB.getDatabase() + "/" + TEST_TABLE; } -// @After - @Before + @After public void cleanTable() { retryCtx.supplyStatus(session -> session.executeSchemeQuery(DROP_TABLE_YQL, new ExecuteSchemeQuerySettings())) .join().expectSuccess("cannot drop table"); @@ -141,24 +140,24 @@ public void writeProtobufToColumnShardTable() { createTable(table); // Write & read batch of 5000 records with id1 = 1 - List batch1 = AllTypesRecord.randomBatch(1, 1, 5000); + List batch1 = AllTypesRecord.randomBatch(1, 1, 2500); bulkUpsert(AllTypesRecord.createProtobufBatch(table, batch1)); // Read table and 
validate data int rows1Count = readTable(1, (idx, rs) -> { Assert.assertTrue("Unexpected row index", idx < batch1.size()); batch1.get(idx).assertRow(columnNames, idx, rs); }); - Assert.assertEquals(5000, rows1Count); + Assert.assertEquals(2500, rows1Count); // Write & read batch of 10000 records with id1 = 2 - List batch2 = AllTypesRecord.randomBatch(2, 1, 10000); + List batch2 = AllTypesRecord.randomBatch(2, 1, 5000); bulkUpsert(AllTypesRecord.createProtobufBatch(table, batch2)); int rows2Count = readTable(2, (idx, rs) -> { Assert.assertTrue("Unexpected row index", idx < batch2.size()); batch2.get(idx).assertRow(columnNames, idx, rs); }); - Assert.assertEquals(10000, rows2Count); + Assert.assertEquals(5000, rows2Count); } @Test diff --git a/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java b/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java index 13dce25fb..687c023be 100644 --- a/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java +++ b/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java @@ -135,20 +135,20 @@ public void baseTypeValidationTest() throws IOException { () -> row.writeYson("col1", new byte[0])); assertIllegalState("cannot call writeDate, actual type: Uuid", - () -> row.writeDate("col1", LocalDate.EPOCH)); + () -> row.writeDate("col1", LocalDate.ofEpochDay(0))); assertIllegalState("cannot call writeDatetime, actual type: Uuid", () -> row.writeDatetime("col1", LocalDateTime.now())); assertIllegalState("cannot call writeTimestamp, actual type: Uuid", - () -> row.writeTimestamp("col1", Instant.EPOCH)); + () -> row.writeTimestamp("col1", Instant.ofEpochSecond(0))); assertIllegalState("cannot call writeInterval, actual type: Uuid", () -> row.writeInterval("col1", Duration.ZERO)); assertIllegalState("cannot call writeDate32, actual type: Uuid", - () -> row.writeDate32("col1", LocalDate.EPOCH)); + () -> row.writeDate32("col1", LocalDate.ofEpochDay(0))); assertIllegalState("cannot call 
writeDatetime64, actual type: Uuid", () -> row.writeDatetime64("col1", LocalDateTime.now())); assertIllegalState("cannot call writeTimestamp64, actual type: Uuid", - () -> row.writeTimestamp64("col1", Instant.EPOCH)); + () -> row.writeTimestamp64("col1", Instant.ofEpochSecond(0))); assertIllegalState("cannot call writeInterval64, actual type: Uuid", () -> row.writeInterval64("col1", Duration.ZERO)); @@ -195,7 +195,8 @@ public void smallIntVectorTest() throws IOException { ApacheArrowWriter.Row row = batch.writeNextRow(); assertIllegalState("cannot call writeUint16, actual type: Date", () -> row.writeUint16("c1", 0)); - assertIllegalState("cannot call writeDate, actual type: Int16", () -> row.writeDate("c2", LocalDate.EPOCH)); + assertIllegalState("cannot call writeDate, actual type: Int16", + () -> row.writeDate("c2", LocalDate.ofEpochDay(0))); assertIllegalState("cannot call writeInt16, actual type: Uint16", () -> row.writeInt16("c3", (short) 0)); BulkUpsertArrowData data = batch.buildBatch(); @@ -220,7 +221,7 @@ public void intVectorTest() throws IOException { assertIllegalState("cannot call writeUint32, actual type: Int32", () -> row.writeUint32("c1", 0)); assertIllegalState("cannot call writeDate32, actual type: Uint32", - () -> row.writeDate32("c2", LocalDate.EPOCH)); + () -> row.writeDate32("c2", LocalDate.ofEpochDay(0))); assertIllegalState("cannot call writeDatetime, actual type: Date32", () -> row.writeDatetime("c3", LocalDateTime.now())); assertIllegalState("cannot call writeInt32, actual type: Datetime", () -> row.writeInt32("c4", 0)); From 81ccc5fd8c0672852d4e8e58f5cd6d6820fa215b Mon Sep 17 00:00:00 2001 From: Alexandr Gorshenin Date: Wed, 18 Feb 2026 13:05:45 +0000 Subject: [PATCH 6/8] Fixes by Copilot --- bom/pom.xml | 2 +- pom.xml | 2 +- table/pom.xml | 2 +- .../ydb/table/query/ApacheArrowWriter.java | 15 +++++------- .../ydb/table/integration/AllTypesRecord.java | 4 ++-- .../ydb/table/integration/BulkUpsertTest.java | 24 +++++++++---------- 
.../table/query/ApacheArrowWriterTest.java | 2 +- 7 files changed, 24 insertions(+), 27 deletions(-) diff --git a/bom/pom.xml b/bom/pom.xml index b453538f8..82db38b38 100644 --- a/bom/pom.xml +++ b/bom/pom.xml @@ -176,7 +176,7 @@ jdk8-bootstrap - [9 + [9,) diff --git a/pom.xml b/pom.xml index 1f897e3cb..fceb9f333 100644 --- a/pom.xml +++ b/pom.xml @@ -251,7 +251,7 @@ jdk8-bootstrap - [9 + [9,) argLine diff --git a/table/pom.xml b/table/pom.xml index e7775b279..88e249d7f 100644 --- a/table/pom.xml +++ b/table/pom.xml @@ -72,7 +72,7 @@ jdk8-bootstrap - [9 + [9,) diff --git a/table/src/main/java/tech/ydb/table/query/ApacheArrowWriter.java b/table/src/main/java/tech/ydb/table/query/ApacheArrowWriter.java index bd7300254..2fa6b4125 100644 --- a/table/src/main/java/tech/ydb/table/query/ApacheArrowWriter.java +++ b/table/src/main/java/tech/ydb/table/query/ApacheArrowWriter.java @@ -3,6 +3,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.Channels; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.Instant; import java.time.LocalDate; @@ -121,10 +122,6 @@ public Batch createNewBatch(int estimatedRowsCount) { @Override public void close() { vsr.close(); - - for (FieldVector field: vsr.getFieldVectors()) { - field.close(); - } } private class BatchImpl implements Batch { @@ -596,7 +593,7 @@ void writeDatetime64(int rowIndex, LocalDateTime datetime) { if (type != PrimitiveType.Datetime64) { throw error("writeDatetime64"); } - vector.setSafe(rowIndex, (int) datetime.toEpochSecond(ZoneOffset.UTC)); + vector.setSafe(rowIndex, datetime.toEpochSecond(ZoneOffset.UTC)); } @Override @@ -671,7 +668,7 @@ void writeText(int rowIndex, String text) { if (type != PrimitiveType.Text) { throw error("writeText"); } - vector.setSafe(rowIndex, text.getBytes()); + vector.setSafe(rowIndex, text.getBytes(StandardCharsets.UTF_8)); } @Override @@ -679,7 +676,7 @@ void writeJson(int rowIndex, String json) { if (type != 
PrimitiveType.Json) { throw error("writeJson"); } - vector.setSafe(rowIndex, json.getBytes()); + vector.setSafe(rowIndex, json.getBytes(StandardCharsets.UTF_8)); } @Override @@ -687,7 +684,7 @@ void writeJsonDocument(int rowIndex, String jsonDocument) { if (type != PrimitiveType.JsonDocument) { throw error("writeJsonDocument"); } - vector.setSafe(rowIndex, jsonDocument.getBytes()); + vector.setSafe(rowIndex, jsonDocument.getBytes(StandardCharsets.UTF_8)); } } @@ -829,7 +826,7 @@ private static Column createColumnVector(BufferAllocator allocator, Type type case Timestamp: return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Int(64, true))); case Interval: - return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Duration(TimeUnit.MILLISECOND))); + return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Duration(TimeUnit.MICROSECOND))); case Date32: return new IntColumn(type, allocator, gen.apply(new ArrowType.Int(32, true))); diff --git a/table/src/test/java/tech/ydb/table/integration/AllTypesRecord.java b/table/src/test/java/tech/ydb/table/integration/AllTypesRecord.java index 9752d1044..19254ccdc 100644 --- a/table/src/test/java/tech/ydb/table/integration/AllTypesRecord.java +++ b/table/src/test/java/tech/ydb/table/integration/AllTypesRecord.java @@ -96,7 +96,7 @@ private AllTypesRecord(long id1, int id2, byte[] payload) { this.id1 = id1; this.id2 = id2; this.length = payload.length; - this.payload = new byte[length]; + this.payload = payload; this.hash = Hashing.sha256().hashBytes(payload).toString(); } @@ -378,7 +378,7 @@ public static ListValue createProtobufBatch(TableDescription desc, List vali @Test public void writeProtobufToDataShardTest() { // Create table - TableDescription table = AllTypesRecord.createTableDescription(false, false); + TableDescription table = AllTypesRecord.createTableDescription(false); createTable(table); Set columnNames = table.getColumns().stream().map(TableColumn::getName).collect(Collectors.toSet()); 
@@ -134,12 +134,12 @@ public void writeProtobufToDataShardTest() { @Test public void writeProtobufToColumnShardTable() { // Create table - TableDescription table = AllTypesRecord.createTableDescription(true, false); + TableDescription table = AllTypesRecord.createTableDescription(true); Set columnNames = table.getColumns().stream().map(TableColumn::getName).collect(Collectors.toSet()); createTable(table); - // Write & read batch of 5000 records with id1 = 1 + // Write & read batch of 2500 records with id1 = 1 List batch1 = AllTypesRecord.randomBatch(1, 1, 2500); bulkUpsert(AllTypesRecord.createProtobufBatch(table, batch1)); // Read table and validate data @@ -149,7 +149,7 @@ public void writeProtobufToColumnShardTable() { }); Assert.assertEquals(2500, rows1Count); - // Write & read batch of 10000 records with id1 = 2 + // Write & read batch of 5000 records with id1 = 2 List batch2 = AllTypesRecord.randomBatch(2, 1, 5000); bulkUpsert(AllTypesRecord.createProtobufBatch(table, batch2)); @@ -163,7 +163,7 @@ public void writeProtobufToColumnShardTable() { @Test public void writeApacheArrowToDataShardTest() { // Create table - TableDescription table = AllTypesRecord.createTableDescription(false, true); + TableDescription table = AllTypesRecord.createTableDescription(false); retryCtx.supplyStatus(s -> s.createTable(tablePath(), table)).join().expectSuccess("Cannot create table"); Set columnNames = table.getColumns().stream().map(TableColumn::getName).collect(Collectors.toSet()); @@ -209,7 +209,7 @@ public void writeApacheArrowToDataShardTest() { @Test public void writeApacheArrowToColumnShardTest() { // Create table - TableDescription table = AllTypesRecord.createTableDescription(true, true); + TableDescription table = AllTypesRecord.createTableDescription(true); retryCtx.supplyStatus(s -> s.createTable(tablePath(), table)).join().expectSuccess("Cannot create table"); Set columnNames = table.getColumns().stream().map(TableColumn::getName).collect(Collectors.toSet()); @@ 
-217,10 +217,10 @@ public void writeApacheArrowToColumnShardTest() { ApacheArrowWriter.Schema schema = ApacheArrowWriter.newSchema(); table.getColumns().forEach(column -> schema.addColumn(column.getName(), column.getType())); + // Create batch of 2500 records + List batch1 = AllTypesRecord.randomBatch(1, 1, 2500); // Create batch of 5000 records - List batch1 = AllTypesRecord.randomBatch(1, 1, 5000); - // Create batch of 10000 records - List batch2 = AllTypesRecord.randomBatch(2, 1, 10000); + List batch2 = AllTypesRecord.randomBatch(2, 1, 5000); try (BufferAllocator allocator = new RootAllocator()) { try (ApacheArrowWriter writer = schema.createWriter(allocator)) { @@ -230,7 +230,7 @@ public void writeApacheArrowToColumnShardTest() { bulkUpsert(data1.buildBatch()); // create batch with estimated size - ApacheArrowWriter.Batch data2 = writer.createNewBatch(10000); + ApacheArrowWriter.Batch data2 = writer.createNewBatch(5000); batch2.forEach(r -> r.writeToApacheArrow(columnNames, data2.writeNextRow())); bulkUpsert(data2.buildBatch()); @@ -243,13 +243,13 @@ public void writeApacheArrowToColumnShardTest() { Assert.assertTrue("Unexpected row index", idx < batch1.size()); batch1.get(idx).assertRow(columnNames, idx, rs); }); - Assert.assertEquals(5000, rows1Count); + Assert.assertEquals(2500, rows1Count); int rows2Count = readTable(2, (idx, rs) -> { Assert.assertTrue("Unexpected row index ", idx < batch2.size()); batch2.get(idx).assertRow(columnNames, idx, rs); }); - Assert.assertEquals(10000, rows2Count); + Assert.assertEquals(5000, rows2Count); } // diff --git a/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java b/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java index 687c023be..4c6a60f28 100644 --- a/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java +++ b/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java @@ -267,7 +267,7 @@ public void bigIntVectorTest() throws IOException { Schema schema = 
readApacheArrowSchema(data.getSchema()); Assert.assertEquals("Schema", + + "c4: Int(64, true), c5: Int(64, true), c6: Duration(MICROSECOND), c7: Int(64, true)>", schema.toString()); } } From 7c18140f355c7fe70e39d63f7c979768e71902b4 Mon Sep 17 00:00:00 2001 From: Alexandr Gorshenin Date: Wed, 18 Feb 2026 13:23:37 +0000 Subject: [PATCH 7/8] Added small test for data validation --- .../ydb/table/integration/BulkUpsertTest.java | 44 --------------- .../table/query/ApacheArrowWriterTest.java | 53 +++++++++++++++++++ 2 files changed, 53 insertions(+), 44 deletions(-) diff --git a/table/src/test/java/tech/ydb/table/integration/BulkUpsertTest.java b/table/src/test/java/tech/ydb/table/integration/BulkUpsertTest.java index 5971abea4..b67d990aa 100644 --- a/table/src/test/java/tech/ydb/table/integration/BulkUpsertTest.java +++ b/table/src/test/java/tech/ydb/table/integration/BulkUpsertTest.java @@ -251,48 +251,4 @@ public void writeApacheArrowToColumnShardTest() { }); Assert.assertEquals(5000, rows2Count); } - -// -// private void assertApacheArrowBatch(ByteString schemaBytes, ByteString batchBytes, Iterator it) { -// try (BufferAllocator allocator = new RootAllocator()) { -// Schema schema = readApacheArrowSchema(schemaBytes); -// try (VectorSchemaRoot vector = VectorSchemaRoot.create(schema, allocator)) { -// try (InputStream is = batchBytes.newInput()) { -// try (ReadChannel channel = new ReadChannel(Channels.newChannel(is))) { -// try (ArrowRecordBatch batch = MessageSerializer.deserializeRecordBatch(channel, allocator)) { -// VectorLoader loader = new VectorLoader(vector); -// loader.load(batch); -// } -// } -// } -// -// UInt8Vector id1 = (UInt8Vector) vector.getVector("id1"); -// BigIntVector id2 = (BigIntVector) vector.getVector("id2"); -// IntVector length = (IntVector) vector.getVector("length"); -// VarCharVector hash = (VarCharVector) vector.getVector("hash"); -// VarBinaryVector data = (VarBinaryVector) vector.getVector("data"); -// TimeStampMicroTZVector tm = 
(TimeStampMicroTZVector) vector.getVector("timestamp"); -// UInt2Vector date = (UInt2Vector) vector.getVector("date"); -// Float8Vector amount = (Float8Vector) vector.getVector("amount"); -// -// for (int idx = 0; idx < vector.getRowCount(); idx++) { -// Assert.assertTrue("Assert has no row " + idx, it.hasNext()); -// Record r = it.next(); -// -// Assert.assertEquals("Row " + idx + " fail", r.id1, id1.get(idx)); -// Assert.assertEquals("Row " + idx + " fail", r.id2, id2.get(idx)); -// Assert.assertEquals("Row " + idx + " fail", r.length, length.get(idx)); -// Assert.assertArrayEquals("Row " + idx + " fail", r.hash.getBytes(), hash.get(idx)); -// Assert.assertArrayEquals("Row " + idx + " fail", r.data, data.get(idx)); -// -// long rm = r.timestamp.getEpochSecond() * 1000000L + r.timestamp.getNano() / 1000; -// Assert.assertEquals("Row " + idx + " fail", rm, tm.get(idx)); -// Assert.assertEquals("Row " + idx + " fail", r.date, LocalDate.ofEpochDay(date.get(idx))); -// Assert.assertEquals("Row " + idx + " fail", r.amount, amount.get(idx), 1e-6); -// } -// } -// } catch (IOException ex) { -// throw new RuntimeException(ex); -// } -// } } diff --git a/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java b/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java index 4c6a60f28..4f48787cb 100644 --- a/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java +++ b/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java @@ -3,6 +3,7 @@ import java.io.IOException; import java.io.InputStream; import java.nio.channels.Channels; +import java.nio.charset.StandardCharsets; import java.time.Duration; import java.time.Instant; import java.time.LocalDate; @@ -11,7 +12,12 @@ import com.google.protobuf.ByteString; import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorLoader; +import 
org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.ipc.ReadChannel; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.AfterClass; @@ -337,6 +343,53 @@ public void fixedSizeBinaryVectorTest() throws IOException { } } + private BulkUpsertArrowData createSimpleBatch() throws IOException { + try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema() + .addColumn("pk", PrimitiveType.Int32) + .addNullableColumn("value", PrimitiveType.Text) + .createWriter(allocator)) { + + ApacheArrowWriter.Batch batch = writer.createNewBatch(2); + ApacheArrowWriter.Row row1 = batch.writeNextRow(); + ApacheArrowWriter.Row row2 = batch.writeNextRow(); + + row1.writeInt32("pk", 1); + row1.writeText("value", "value-1"); + + row2.writeInt32("pk", 2); + row2.writeText("value", "значение-2"); + + return batch.buildBatch(); + } + } + + @Test + public void readArrayBatchTest() throws IOException { + BulkUpsertArrowData data = createSimpleBatch(); + + Schema schema = readApacheArrowSchema(data.getSchema()); + try (VectorSchemaRoot vector = VectorSchemaRoot.create(schema, allocator)) { + // readApacheArrowBatch + try (InputStream is = data.getData().newInput()) { + try (ReadChannel channel = new ReadChannel(Channels.newChannel(is))) { + try (ArrowRecordBatch batch = MessageSerializer.deserializeRecordBatch(channel, allocator)) { + VectorLoader loader = new VectorLoader(vector); + loader.load(batch); + } + } + } + + IntVector pk = (IntVector) vector.getVector("pk"); + VarCharVector value = (VarCharVector) vector.getVector("value"); + + Assert.assertEquals(2, vector.getRowCount()); + Assert.assertEquals(1, pk.get(0)); + Assert.assertEquals(2, pk.get(1)); + Assert.assertEquals("value-1", new String(value.get(0), StandardCharsets.UTF_8)); + Assert.assertEquals("значение-2", new String(value.get(1), StandardCharsets.UTF_8)); + } + } + 
private static Schema readApacheArrowSchema(ByteString bytes) { try (InputStream is = bytes.newInput()) { try (ReadChannel channel = new ReadChannel(Channels.newChannel(is))) { From 8624d3e3b1cd87077036e7d215a30c33fe848b13 Mon Sep 17 00:00:00 2001 From: Alexandr Gorshenin Date: Thu, 19 Feb 2026 15:01:28 +0000 Subject: [PATCH 8/8] Removed BulkUpsertCsvData --- .../ydb/table/query/BulkUpsertCsvData.java | 23 ------------------- 1 file changed, 23 deletions(-) delete mode 100644 table/src/main/java/tech/ydb/table/query/BulkUpsertCsvData.java diff --git a/table/src/main/java/tech/ydb/table/query/BulkUpsertCsvData.java b/table/src/main/java/tech/ydb/table/query/BulkUpsertCsvData.java deleted file mode 100644 index bb9c327fd..000000000 --- a/table/src/main/java/tech/ydb/table/query/BulkUpsertCsvData.java +++ /dev/null @@ -1,23 +0,0 @@ -package tech.ydb.table.query; - -import com.google.protobuf.ByteString; - -import tech.ydb.proto.formats.YdbFormats; -import tech.ydb.proto.table.YdbTable; - -/** - * - * @author Aleksandr Gorshenin - */ -public class BulkUpsertCsvData implements BulkUpsertData { - private final ByteString data; - - public BulkUpsertCsvData(ByteString data) { - this.data = data; - } - - @Override - public void applyToRequest(YdbTable.BulkUpsertRequest.Builder builder) { - builder.setCsvSettings(YdbFormats.CsvSettings.newBuilder().build()).setData(data); - } -}