diff --git a/pom.xml b/pom.xml
index 9207b63dc..f47a2634c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -40,8 +40,9 @@
2.25.3
2.13.2
5.21.0
-
2.0.3
+
+ <apache.arrow.version>17.0.0</apache.arrow.version>
@@ -88,6 +89,17 @@
${gson.version}
+ <dependency>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>arrow-vector</artifactId>
+ <version>${apache.arrow.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>arrow-memory-netty</artifactId>
+ <version>${apache.arrow.version}</version>
+ </dependency>
+
junit
@@ -238,11 +250,14 @@
jdk8-bootstrap
- [9
+ [9,)
argLine
+
+
+
diff --git a/table/pom.xml b/table/pom.xml
index 4c5dc1f58..a74e625d1 100644
--- a/table/pom.xml
+++ b/table/pom.xml
@@ -30,17 +30,26 @@
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>arrow-vector</artifactId>
+ <optional>true</optional>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>arrow-memory-netty</artifactId>
+ <optional>true</optional>
+ </dependency>
+
+
junit
junit
test
-
tech.ydb.test
ydb-junit4-support
test
-
org.apache.logging.log4j
log4j-slf4j2-impl
@@ -57,9 +66,23 @@
true
ydbplatform/local-ydb:trunk
+ enable_columnshard_bool
+
+
+
+ jdk8-bootstrap
+
+ [9,)
+
+
+
+ --add-opens=java.base/java.nio=ALL-UNNAMED
+
+
+
diff --git a/table/src/main/java/tech/ydb/table/Session.java b/table/src/main/java/tech/ydb/table/Session.java
index 7793c972e..d4da4e7a1 100644
--- a/table/src/main/java/tech/ydb/table/Session.java
+++ b/table/src/main/java/tech/ydb/table/Session.java
@@ -199,7 +199,7 @@ default CompletableFuture executeScanQuery(String query, Params params,
CompletableFuture<Result<State>> keepAlive(KeepAliveSessionSettings settings);
default CompletableFuture<Status> executeBulkUpsert(String tablePath, ListValue rows, BulkUpsertSettings settings) {
- return executeBulkUpsert(tablePath, new BulkUpsertData(rows), settings);
+ return executeBulkUpsert(tablePath, BulkUpsertData.fromRows(rows), settings);
}
CompletableFuture<Status> executeBulkUpsert(String tablePath, BulkUpsertData data, BulkUpsertSettings settings);
diff --git a/table/src/main/java/tech/ydb/table/query/ApacheArrowWriter.java b/table/src/main/java/tech/ydb/table/query/ApacheArrowWriter.java
new file mode 100644
index 000000000..2fa6b4125
--- /dev/null
+++ b/table/src/main/java/tech/ydb/table/query/ApacheArrowWriter.java
@@ -0,0 +1,868 @@
+package tech.ydb.table.query;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.Channels;
+import java.nio.charset.StandardCharsets;
+import java.time.Duration;
+import java.time.Instant;
+import java.time.LocalDate;
+import java.time.LocalDateTime;
+import java.time.ZoneOffset;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.function.Function;
+
+import com.google.protobuf.ByteString;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.BaseFixedWidthVector;
+import org.apache.arrow.vector.BaseVariableWidthVector;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.FixedSizeBinaryVector;
+import org.apache.arrow.vector.Float4Vector;
+import org.apache.arrow.vector.Float8Vector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.SmallIntVector;
+import org.apache.arrow.vector.TinyIntVector;
+import org.apache.arrow.vector.VarBinaryVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.ipc.WriteChannel;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.ipc.message.MessageSerializer;
+import org.apache.arrow.vector.types.FloatingPointPrecision;
+import org.apache.arrow.vector.types.TimeUnit;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+
+import tech.ydb.table.utils.LittleEndian;
+import tech.ydb.table.values.DecimalValue;
+import tech.ydb.table.values.PrimitiveType;
+import tech.ydb.table.values.Type;
+import tech.ydb.table.values.proto.ProtoValue;
+
+/**
+ *
+ * @author Aleksandr Gorshenin
+ */
+public class ApacheArrowWriter implements AutoCloseable {
+ public interface Batch {
+ Row writeNextRow();
+ BulkUpsertArrowData buildBatch() throws IOException;
+ }
+
+ public interface Row {
+ void writeNull(String column);
+
+ void writeBool(String column, boolean value);
+
+ void writeInt8(String column, byte value);
+ void writeInt16(String column, short value);
+ void writeInt32(String column, int value);
+ void writeInt64(String column, long value);
+
+ void writeUint8(String column, int value);
+ void writeUint16(String column, int value);
+ void writeUint32(String column, long value);
+ void writeUint64(String column, long value);
+
+ void writeFloat(String column, float value);
+ void writeDouble(String column, double value);
+
+ void writeText(String column, String text);
+ void writeJson(String column, String json);
+ void writeJsonDocument(String column, String jsonDocument);
+
+ void writeBytes(String column, byte[] bytes);
+ void writeYson(String column, byte[] yson);
+
+ void writeUuid(String column, UUID uuid);
+
+ void writeDate(String column, LocalDate date);
+ void writeDatetime(String column, LocalDateTime datetime);
+ void writeTimestamp(String column, Instant instant);
+ void writeInterval(String column, Duration interval);
+
+ void writeDate32(String column, LocalDate date32);
+ void writeDatetime64(String column, LocalDateTime datetime64);
+ void writeTimestamp64(String column, Instant instant64);
+ void writeInterval64(String column, Duration interval64);
+
+ void writeDecimal(String column, DecimalValue value);
+ }
+
+ private final VectorSchemaRoot vsr;
+ private final Map<String, Column<?>> columns = new HashMap<>();
+
+ private ApacheArrowWriter(BufferAllocator allocator, List<ColumnInfo> columnsList) {
+ FieldVector[] vectors = new FieldVector[columnsList.size()];
+ for (int idx = 0; idx < columnsList.size(); idx += 1) {
+ ColumnInfo column = columnsList.get(idx);
+ Column<?> vector = column.createVector(allocator);
+ vectors[idx] = vector.vector;
+ columns.put(column.name, vector);
+ }
+
+ this.vsr = VectorSchemaRoot.of(vectors);
+ }
+
+ public Batch createNewBatch(int estimatedRowsCount) {
+ // reset all
+ for (Column<?> column: columns.values()) {
+ column.allocateNew(estimatedRowsCount);
+ }
+ return new BatchImpl();
+ }
+
+ @Override
+ public void close() {
+ vsr.close();
+ }
+
+ private class BatchImpl implements Batch {
+ private int rowIndex = 0;
+
+ @Override
+ public Row writeNextRow() {
+ return new RowImpl(rowIndex++);
+ }
+
+ @Override
+ public BulkUpsertArrowData buildBatch() throws IOException {
+ vsr.setRowCount(rowIndex);
+ return new BulkUpsertArrowData(serializeSchema(), serializeBatch());
+ }
+
+ private ByteString serializeSchema() throws IOException {
+ try (ByteString.Output out = ByteString.newOutput()) {
+ try (WriteChannel channel = new WriteChannel(Channels.newChannel(out))) {
+ MessageSerializer.serialize(channel, vsr.getSchema());
+ return out.toByteString();
+ }
+ }
+ }
+
+ private ByteString serializeBatch() throws IOException {
+ try (ByteString.Output out = ByteString.newOutput()) {
+ try (WriteChannel channel = new WriteChannel(Channels.newChannel(out))) {
+ VectorUnloader loader = new VectorUnloader(vsr);
+ try (ArrowRecordBatch batch = loader.getRecordBatch()) {
+ MessageSerializer.serialize(channel, batch);
+ return out.toByteString();
+ }
+ }
+ }
+ }
+ }
+
+ private class RowImpl implements Row {
+ private final int rowIndex;
+
+ RowImpl(int rowIndex) {
+ this.rowIndex = rowIndex;
+ }
+
+ private Column<?> find(String column) {
+ Column<?> vector = columns.get(column);
+ if (vector == null) {
+ throw new IllegalArgumentException("Column '" + column + "' not found");
+ }
+ return vector;
+ }
+
+ @Override
+ public void writeNull(String column) {
+ find(column).writeNull(rowIndex);
+ }
+
+ @Override
+ public void writeBool(String column, boolean value) {
+ find(column).writeBool(rowIndex, value);
+ }
+
+ @Override
+ public void writeInt8(String column, byte value) {
+ find(column).writeInt8(rowIndex, value);
+ }
+
+ @Override
+ public void writeInt16(String column, short value) {
+ find(column).writeInt16(rowIndex, value);
+ }
+
+ @Override
+ public void writeInt32(String column, int value) {
+ find(column).writeInt32(rowIndex, value);
+ }
+
+ @Override
+ public void writeInt64(String column, long value) {
+ find(column).writeInt64(rowIndex, value);
+ }
+
+ @Override
+ public void writeUint8(String column, int value) {
+ find(column).writeUint8(rowIndex, value);
+ }
+
+ @Override
+ public void writeUint16(String column, int value) {
+ find(column).writeUint16(rowIndex, value);
+ }
+
+ @Override
+ public void writeUint32(String column, long value) {
+ find(column).writeUint32(rowIndex, value);
+ }
+
+ @Override
+ public void writeUint64(String column, long value) {
+ find(column).writeUint64(rowIndex, value);
+ }
+
+ @Override
+ public void writeFloat(String column, float value) {
+ find(column).writeFloat(rowIndex, value);
+ }
+
+ @Override
+ public void writeDouble(String column, double value) {
+ find(column).writeDouble(rowIndex, value);
+ }
+
+ @Override
+ public void writeText(String column, String text) {
+ find(column).writeText(rowIndex, text);
+ }
+
+ @Override
+ public void writeJson(String column, String json) {
+ find(column).writeJson(rowIndex, json);
+ }
+
+ @Override
+ public void writeJsonDocument(String column, String jsonDocument) {
+ find(column).writeJsonDocument(rowIndex, jsonDocument);
+ }
+
+ @Override
+ public void writeBytes(String column, byte[] bytes) {
+ find(column).writeBytes(rowIndex, bytes);
+ }
+
+ @Override
+ public void writeYson(String column, byte[] yson) {
+ find(column).writeYson(rowIndex, yson);
+ }
+
+ @Override
+ public void writeUuid(String column, UUID uuid) {
+ find(column).writeUuid(rowIndex, uuid);
+ }
+
+ @Override
+ public void writeDate(String column, LocalDate date) {
+ find(column).writeDate(rowIndex, date);
+ }
+
+ @Override
+ public void writeDatetime(String column, LocalDateTime datetime) {
+ find(column).writeDatetime(rowIndex, datetime);
+ }
+
+ @Override
+ public void writeTimestamp(String column, Instant instant) {
+ find(column).writeTimestamp(rowIndex, instant);
+ }
+
+ @Override
+ public void writeInterval(String column, Duration interval) {
+ find(column).writeInterval(rowIndex, interval);
+ }
+
+ @Override
+ public void writeDate32(String column, LocalDate date32) {
+ find(column).writeDate32(rowIndex, date32);
+ }
+
+ @Override
+ public void writeDatetime64(String column, LocalDateTime datetime64) {
+ find(column).writeDatetime64(rowIndex, datetime64);
+ }
+
+ @Override
+ public void writeTimestamp64(String column, Instant instant64) {
+ find(column).writeTimestamp64(rowIndex, instant64);
+ }
+
+ @Override
+ public void writeInterval64(String column, Duration interval64) {
+ find(column).writeInterval64(rowIndex, interval64);
+ }
+
+ @Override
+ public void writeDecimal(String column, DecimalValue value) {
+ find(column).writeDecimal(rowIndex, value);
+ }
+ }
+
+ private abstract static class Column<T extends FieldVector> {
+ protected final Field field;
+ protected final Type type;
+ protected final T vector;
+
+ Column(Field field, Type type, T vector) {
+ this.field = field;
+ this.type = type;
+ this.vector = vector;
+ }
+
+ protected IllegalStateException error(String method) {
+ return new IllegalStateException("cannot call " + method + ", actual type: " + type);
+ }
+
+ public abstract void allocateNew(int estimated);
+
+ void writeNull(int rowIndex) {
+ if (field.isNullable()) {
+ vector.setNull(rowIndex);
+ } else {
+ throw error("writeNull");
+ }
+ }
+
+ void writeBool(int rowIndex, boolean value) {
+ throw error("writeBool");
+ }
+
+ void writeInt8(int rowIndex, byte value) {
+ throw error("writeInt8");
+ }
+
+ void writeInt16(int rowIndex, short value) {
+ throw error("writeInt16");
+ }
+
+ void writeInt32(int rowIndex, int value) {
+ throw error("writeInt32");
+ }
+
+ void writeInt64(int rowIndex, long value) {
+ throw error("writeInt64");
+ }
+
+ void writeUint8(int rowIndex, int value) {
+ throw error("writeUint8");
+ }
+
+ void writeUint16(int rowIndex, int value) {
+ throw error("writeUint16");
+ }
+
+ void writeUint32(int rowIndex, long value) {
+ throw error("writeUint32");
+ }
+
+ void writeUint64(int rowIndex, long value) {
+ throw error("writeUint64");
+ }
+
+ void writeFloat(int rowIndex, float value) {
+ throw error("writeFloat");
+ }
+
+ void writeDouble(int rowIndex, double value) {
+ throw error("writeDouble");
+ }
+
+ void writeText(int rowIndex, String text) {
+ throw error("writeText");
+ }
+
+ void writeJson(int rowIndex, String json) {
+ throw error("writeJson");
+ }
+
+ void writeJsonDocument(int rowIndex, String jsonDocument) {
+ throw error("writeJsonDocument");
+ }
+
+ void writeBytes(int rowIndex, byte[] bytes) {
+ throw error("writeBytes");
+ }
+
+ void writeYson(int rowIndex, byte[] yson) {
+ throw error("writeYson");
+ }
+
+ void writeUuid(int rowIndex, UUID yson) {
+ throw error("writeUuid");
+ }
+
+ void writeDate(int rowIndex, LocalDate date) {
+ throw error("writeDate");
+ }
+
+ void writeDatetime(int rowIndex, LocalDateTime datetime) {
+ throw error("writeDatetime");
+ }
+
+ void writeTimestamp(int rowIndex, Instant instant) {
+ throw error("writeTimestamp");
+ }
+
+ void writeInterval(int rowIndex, Duration interval) {
+ throw error("writeInterval");
+ }
+
+ void writeDate32(int rowIndex, LocalDate date32) {
+ throw error("writeDate32");
+ }
+
+ void writeDatetime64(int rowIndex, LocalDateTime datetime64) {
+ throw error("writeDatetime64");
+ }
+
+ void writeTimestamp64(int rowIndex, Instant instant64) {
+ throw error("writeTimestamp64");
+ }
+
+ void writeInterval64(int rowIndex, Duration interval64) {
+ throw error("writeInterval64");
+ }
+
+ void writeDecimal(int rowIndex, DecimalValue value) {
+ throw error("writeDecimal");
+ }
+ }
+
+ private static class FixedWidthColumn<T extends BaseFixedWidthVector> extends Column<T> {
+
+ FixedWidthColumn(Field field, Type type, T vector) {
+ super(field, type, vector);
+ }
+
+ @Override
+ public void allocateNew(int estimated) {
+ vector.allocateNew(estimated);
+ }
+ }
+
+ private static class VariableWidthColumn<T extends BaseVariableWidthVector> extends Column<T> {
+
+ VariableWidthColumn(Field field, Type type, T vector) {
+ super(field, type, vector);
+ }
+
+ @Override
+ public void allocateNew(int estimated) {
+ vector.allocateNew(estimated);
+ }
+ }
+
+ private static class TinyIntColumn extends FixedWidthColumn<TinyIntVector> {
+
+ TinyIntColumn(Type type, BufferAllocator allocator, Field field) {
+ super(field, type, new TinyIntVector(field, allocator));
+ }
+
+ @Override
+ void writeBool(int rowIndex, boolean value) {
+ if (type != PrimitiveType.Bool) {
+ throw error("writeBool");
+ }
+ vector.setSafe(rowIndex, value ? 1 : 0);
+ }
+
+ @Override
+ void writeInt8(int rowIndex, byte value) {
+ if (type != PrimitiveType.Int8) {
+ throw error("writeInt8");
+ }
+ vector.setSafe(rowIndex, value);
+ }
+
+ @Override
+ void writeUint8(int rowIndex, int value) {
+ if (type != PrimitiveType.Uint8) {
+ throw error("writeUint8");
+ }
+ vector.setSafe(rowIndex, value);
+ }
+ }
+
+ private static class SmallIntColumn extends FixedWidthColumn<SmallIntVector> {
+
+ SmallIntColumn(Type type, BufferAllocator allocator, Field field) {
+ super(field, type, new SmallIntVector(field, allocator));
+ }
+
+ @Override
+ void writeInt16(int rowIndex, short value) {
+ if (type != PrimitiveType.Int16) {
+ throw error("writeInt16");
+ }
+ vector.setSafe(rowIndex, value);
+ }
+
+ @Override
+ void writeUint16(int rowIndex, int value) {
+ if (type != PrimitiveType.Uint16) {
+ throw error("writeUint16");
+ }
+ vector.setSafe(rowIndex, value);
+ }
+
+ @Override
+ void writeDate(int rowIndex, LocalDate date) {
+ if (type != PrimitiveType.Date) {
+ throw error("writeDate");
+ }
+ vector.setSafe(rowIndex, (int) date.toEpochDay());
+ }
+ }
+
+ private static class IntColumn extends FixedWidthColumn<IntVector> {
+
+ IntColumn(Type type, BufferAllocator allocator, Field field) {
+ super(field, type, new IntVector(field, allocator));
+ }
+
+ @Override
+ void writeInt32(int rowIndex, int value) {
+ if (type != PrimitiveType.Int32) {
+ throw error("writeInt32");
+ }
+ vector.setSafe(rowIndex, value);
+ }
+
+ @Override
+ void writeUint32(int rowIndex, long value) {
+ if (type != PrimitiveType.Uint32) {
+ throw error("writeUint32");
+ }
+ vector.setSafe(rowIndex, (int) value);
+ }
+
+ @Override
+ void writeDate32(int rowIndex, LocalDate date) {
+ if (type != PrimitiveType.Date32) {
+ throw error("writeDate32");
+ }
+ vector.setSafe(rowIndex, (int) date.toEpochDay());
+ }
+
+ @Override
+ void writeDatetime(int rowIndex, LocalDateTime datetime) {
+ if (type != PrimitiveType.Datetime) {
+ throw error("writeDatetime");
+ }
+ vector.setSafe(rowIndex, (int) datetime.toEpochSecond(ZoneOffset.UTC));
+ }
+ }
+
+ private static class BigIntColumn extends FixedWidthColumn<BigIntVector> {
+
+ BigIntColumn(Type type, BufferAllocator allocator, Field field) {
+ super(field, type, new BigIntVector(field, allocator));
+ }
+
+ @Override
+ void writeInt64(int rowIndex, long value) {
+ if (type != PrimitiveType.Int64) {
+ throw error("writeInt64");
+ }
+ vector.setSafe(rowIndex, value);
+ }
+
+ @Override
+ void writeUint64(int rowIndex, long value) {
+ if (type != PrimitiveType.Uint64) {
+ throw error("writeUint64");
+ }
+ vector.setSafe(rowIndex, value);
+ }
+
+ @Override
+ void writeDatetime64(int rowIndex, LocalDateTime datetime) {
+ if (type != PrimitiveType.Datetime64) {
+ throw error("writeDatetime64");
+ }
+ vector.setSafe(rowIndex, datetime.toEpochSecond(ZoneOffset.UTC));
+ }
+
+ @Override
+ void writeTimestamp(int rowIndex, Instant value) {
+ if (type != PrimitiveType.Timestamp) {
+ throw error("writeTimestamp");
+ }
+ long micros = value.getEpochSecond() * 1000000L + value.getNano() / 1000;
+ vector.setSafe(rowIndex, micros);
+ }
+
+ @Override
+ void writeTimestamp64(int rowIndex, Instant value) {
+ if (type != PrimitiveType.Timestamp64) {
+ throw error("writeTimestamp64");
+ }
+ long micros = value.getEpochSecond() * 1000000L + value.getNano() / 1000;
+ vector.setSafe(rowIndex, micros);
+ }
+
+ @Override
+ void writeInterval(int rowIndex, Duration duration) {
+ if (type != PrimitiveType.Interval) {
+ throw error("writeInterval");
+ }
+ long micros = duration.getSeconds() * 1000000L + duration.getNano() / 1000;
+ vector.setSafe(rowIndex, micros);
+ }
+
+ @Override
+ void writeInterval64(int rowIndex, Duration duration) {
+ if (type != PrimitiveType.Interval64) {
+ throw error("writeInterval64");
+ }
+ long micros = duration.getSeconds() * 1000000L + duration.getNano() / 1000;
+ vector.setSafe(rowIndex, micros);
+ }
+ }
+
+ private static class FloatColumn extends FixedWidthColumn<Float4Vector> {
+
+ FloatColumn(Type type, BufferAllocator allocator, Field field) {
+ super(field, type, new Float4Vector(field, allocator));
+ }
+
+ @Override
+ void writeFloat(int rowIndex, float value) {
+ vector.setSafe(rowIndex, value);
+ }
+ }
+
+ private static class DoubleColumn extends FixedWidthColumn<Float8Vector> {
+
+ DoubleColumn(Type type, BufferAllocator allocator, Field field) {
+ super(field, type, new Float8Vector(field, allocator));
+ }
+
+ @Override
+ void writeDouble(int rowIndex, double value) {
+ vector.setSafe(rowIndex, value);
+ }
+ }
+
+ private static class VarCharColumn extends VariableWidthColumn<VarCharVector> {
+
+ VarCharColumn(Type type, BufferAllocator allocator, Field field) {
+ super(field, type, new VarCharVector(field, allocator));
+ }
+
+ @Override
+ void writeText(int rowIndex, String text) {
+ if (type != PrimitiveType.Text) {
+ throw error("writeText");
+ }
+ vector.setSafe(rowIndex, text.getBytes(StandardCharsets.UTF_8));
+ }
+
+ @Override
+ void writeJson(int rowIndex, String json) {
+ if (type != PrimitiveType.Json) {
+ throw error("writeJson");
+ }
+ vector.setSafe(rowIndex, json.getBytes(StandardCharsets.UTF_8));
+ }
+
+ @Override
+ void writeJsonDocument(int rowIndex, String jsonDocument) {
+ if (type != PrimitiveType.JsonDocument) {
+ throw error("writeJsonDocument");
+ }
+ vector.setSafe(rowIndex, jsonDocument.getBytes(StandardCharsets.UTF_8));
+ }
+ }
+
+ private static class VarBinaryColumn extends VariableWidthColumn<VarBinaryVector> {
+
+ VarBinaryColumn(Type type, BufferAllocator allocator, Field field) {
+ super(field, type, new VarBinaryVector(field, allocator));
+ }
+
+ @Override
+ void writeBytes(int rowIndex, byte[] bytes) {
+ if (type != PrimitiveType.Bytes) {
+ throw error("writeBytes");
+ }
+ vector.setSafe(rowIndex, bytes);
+ }
+
+ @Override
+ void writeYson(int rowIndex, byte[] yson) {
+ if (type != PrimitiveType.Yson) {
+ throw error("writeYson");
+ }
+ vector.setSafe(rowIndex, yson);
+ }
+ }
+
+ private static class FixedBinaryColumn extends FixedWidthColumn<FixedSizeBinaryVector> {
+
+ FixedBinaryColumn(Type type, BufferAllocator allocator, Field field) {
+ super(field, type, new FixedSizeBinaryVector(field, allocator));
+ }
+
+ /**
+ * @see ProtoValue#newUuid(java.util.UUID)
+ */
+ @Override
+ void writeUuid(int rowIndex, UUID uuid) {
+ if (type != PrimitiveType.Uuid) {
+ throw error("writeUuid");
+ }
+
+ ByteBuffer buf = ByteBuffer.allocate(16);
+
+ long msb = uuid.getMostSignificantBits();
+ long timeLow = (msb & 0xffffffff00000000L) >>> 32;
+ long timeMid = (msb & 0x00000000ffff0000L) << 16;
+ long timeHighAndVersion = (msb & 0x000000000000ffffL) << 48;
+ buf.putLong(LittleEndian.bswap(timeLow | timeMid | timeHighAndVersion));
+ buf.putLong(uuid.getLeastSignificantBits());
+
+ vector.setSafe(rowIndex, buf.array());
+ }
+
+ @Override
+ void writeDecimal(int rowIndex, DecimalValue value) {
+ if (type.getKind() != Type.Kind.DECIMAL) {
+ throw error("writeDecimal");
+ }
+
+ ByteBuffer buf = ByteBuffer.allocate(16);
+ buf.putLong(LittleEndian.bswap(value.getLow()));
+ buf.putLong(LittleEndian.bswap(value.getHigh()));
+
+ vector.setSafe(rowIndex, buf.array());
+ }
+ }
+
+ private static class ColumnInfo {
+ private final String name;
+ private final Type type;
+
+ ColumnInfo(String name, Type type) {
+ this.name = name;
+ this.type = type;
+ }
+
+ private Column<?> createVector(BufferAllocator allocator) {
+ if (type.getKind() == Type.Kind.OPTIONAL) {
+ return createColumnVector(allocator, type.unwrapOptional(), t -> Field.nullable(name, t));
+ }
+
+ return createColumnVector(allocator, type, t -> Field.notNullable(name, t));
+ }
+ }
+
+ private static Column<?> createColumnVector(BufferAllocator allocator, Type type, Function<ArrowType, Field> gen) {
+ if (type.getKind() == Type.Kind.DECIMAL) {
+ return new FixedBinaryColumn(type, allocator, gen.apply(new ArrowType.FixedSizeBinary(16)));
+ }
+
+ if (type.getKind() == Type.Kind.PRIMITIVE) {
+ switch ((PrimitiveType) type) {
+ case Bool:
+ return new TinyIntColumn(type, allocator, gen.apply(new ArrowType.Int(8, false)));
+
+ case Int8:
+ return new TinyIntColumn(type, allocator, gen.apply(new ArrowType.Int(8, true)));
+ case Int16:
+ return new SmallIntColumn(type, allocator, gen.apply(new ArrowType.Int(16, true)));
+ case Int32:
+ return new IntColumn(type, allocator, gen.apply(new ArrowType.Int(32, true)));
+ case Int64:
+ return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Int(64, true)));
+
+ case Uint8:
+ return new TinyIntColumn(type, allocator, gen.apply(new ArrowType.Int(8, false)));
+ case Uint16:
+ return new SmallIntColumn(type, allocator, gen.apply(new ArrowType.Int(16, false)));
+ case Uint32:
+ return new IntColumn(type, allocator, gen.apply(new ArrowType.Int(32, false)));
+ case Uint64:
+ return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Int(64, false)));
+
+ case Float:
+ return new FloatColumn(type, allocator, gen.apply(
+ new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
+ ));
+ case Double:
+ return new DoubleColumn(type, allocator, gen.apply(
+ new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
+ ));
+
+ case Text:
+ case Json:
+ case JsonDocument:
+ return new VarCharColumn(type, allocator, gen.apply(new ArrowType.Utf8()));
+
+ case Bytes:
+ case Yson:
+ return new VarBinaryColumn(type, allocator, gen.apply(new ArrowType.Binary()));
+
+ case Uuid:
+ return new FixedBinaryColumn(type, allocator, gen.apply(new ArrowType.FixedSizeBinary(16)));
+
+ case Date:
+ return new SmallIntColumn(type, allocator, gen.apply(new ArrowType.Int(16, false)));
+ case Datetime:
+ return new IntColumn(type, allocator, gen.apply(new ArrowType.Int(32, false)));
+ case Timestamp:
+ return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Int(64, true)));
+ case Interval:
+ return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Duration(TimeUnit.MICROSECOND)));
+
+ case Date32:
+ return new IntColumn(type, allocator, gen.apply(new ArrowType.Int(32, true)));
+ case Datetime64:
+ return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Int(64, true)));
+ case Timestamp64:
+ return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Int(64, true)));
+ case Interval64:
+ return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Int(64, true)));
+
+ default:
+ break;
+ }
+ }
+
+ throw new IllegalArgumentException("Type " + type + " is not supported in ArrowWriter");
+ }
+
+ public static Schema newSchema() {
+ return new Schema();
+ }
+
+ public static class Schema {
+ private final List<ColumnInfo> columns = new ArrayList<>();
+
+ public Schema addColumn(String name, Type type) {
+ this.columns.add(new ColumnInfo(name, type));
+ return this;
+ }
+
+ public Schema addNullableColumn(String name, Type type) {
+ return addColumn(name, type.makeOptional());
+ }
+
+ public ApacheArrowWriter createWriter(BufferAllocator allocator) {
+ return new ApacheArrowWriter(allocator, columns);
+ }
+ }
+}
diff --git a/table/src/main/java/tech/ydb/table/query/BulkUpsertArrowData.java b/table/src/main/java/tech/ydb/table/query/BulkUpsertArrowData.java
new file mode 100644
index 000000000..1b2d23e8b
--- /dev/null
+++ b/table/src/main/java/tech/ydb/table/query/BulkUpsertArrowData.java
@@ -0,0 +1,35 @@
+package tech.ydb.table.query;
+
+import com.google.protobuf.ByteString;
+
+import tech.ydb.proto.formats.YdbFormats;
+import tech.ydb.proto.table.YdbTable;
+
+/**
+ *
+ * @author Aleksandr Gorshenin
+ */
+public class BulkUpsertArrowData implements BulkUpsertData {
+ private final ByteString schema;
+ private final ByteString data;
+
+ public BulkUpsertArrowData(ByteString schema, ByteString data) {
+ this.schema = schema;
+ this.data = data;
+ }
+
+ public ByteString getSchema() {
+ return schema;
+ }
+
+ public ByteString getData() {
+ return data;
+ }
+
+ @Override
+ public void applyToRequest(YdbTable.BulkUpsertRequest.Builder builder) {
+ builder.setArrowBatchSettings(
+ YdbFormats.ArrowBatchSettings.newBuilder().setSchema(schema).build()
+ ).setData(data);
+ }
+}
diff --git a/table/src/main/java/tech/ydb/table/query/BulkUpsertData.java b/table/src/main/java/tech/ydb/table/query/BulkUpsertData.java
index 73d170ffc..96dee3b1d 100644
--- a/table/src/main/java/tech/ydb/table/query/BulkUpsertData.java
+++ b/table/src/main/java/tech/ydb/table/query/BulkUpsertData.java
@@ -9,21 +9,14 @@
*
* @author Aleksandr Gorshenin
*/
-public class BulkUpsertData {
- private final ValueProtos.TypedValue rows;
+public interface BulkUpsertData {
+ void applyToRequest(YdbTable.BulkUpsertRequest.Builder builder);
- public BulkUpsertData(ListValue rows) {
- this.rows = ValueProtos.TypedValue.newBuilder()
- .setType(rows.getType().toPb())
- .setValue(rows.toPb())
- .build();
+ static BulkUpsertData fromRows(ListValue list) {
+ return new BulkUpsertProtoData(list);
}
- public BulkUpsertData(ValueProtos.TypedValue rows) {
- this.rows = rows;
- }
-
- public void applyToRequest(YdbTable.BulkUpsertRequest.Builder builder) {
- builder.setRows(rows);
+ static BulkUpsertData fromProto(ValueProtos.TypedValue rows) {
+ return new BulkUpsertProtoData(rows);
}
}
diff --git a/table/src/main/java/tech/ydb/table/query/BulkUpsertProtoData.java b/table/src/main/java/tech/ydb/table/query/BulkUpsertProtoData.java
new file mode 100644
index 000000000..17d575f46
--- /dev/null
+++ b/table/src/main/java/tech/ydb/table/query/BulkUpsertProtoData.java
@@ -0,0 +1,29 @@
+package tech.ydb.table.query;
+
+import tech.ydb.proto.ValueProtos;
+import tech.ydb.proto.table.YdbTable;
+import tech.ydb.table.values.ListValue;
+
+/**
+ *
+ * @author Aleksandr Gorshenin
+ */
+public class BulkUpsertProtoData implements BulkUpsertData {
+ private final ValueProtos.TypedValue rows;
+
+ public BulkUpsertProtoData(ListValue rows) {
+ this.rows = ValueProtos.TypedValue.newBuilder()
+ .setType(rows.getType().toPb())
+ .setValue(rows.toPb())
+ .build();
+ }
+
+ public BulkUpsertProtoData(ValueProtos.TypedValue rows) {
+ this.rows = rows;
+ }
+
+ @Override
+ public void applyToRequest(YdbTable.BulkUpsertRequest.Builder builder) {
+ builder.setRows(rows);
+ }
+}
diff --git a/table/src/main/java/tech/ydb/table/utils/LittleEndian.java b/table/src/main/java/tech/ydb/table/utils/LittleEndian.java
index b1c7fc56f..816d1ba1e 100644
--- a/table/src/main/java/tech/ydb/table/utils/LittleEndian.java
+++ b/table/src/main/java/tech/ydb/table/utils/LittleEndian.java
@@ -8,6 +8,9 @@ private LittleEndian() { }
/**
* Reverses the byte order of a long value.
+ *
+ * @param v long value
+ * @return reversed long value
*/
public static long bswap(long v) {
return
diff --git a/table/src/test/java/tech/ydb/table/integration/AllTypesRecord.java b/table/src/test/java/tech/ydb/table/integration/AllTypesRecord.java
new file mode 100644
index 000000000..19254ccdc
--- /dev/null
+++ b/table/src/test/java/tech/ydb/table/integration/AllTypesRecord.java
@@ -0,0 +1,438 @@
+package tech.ydb.table.integration;
+
+import java.math.BigInteger;
+import java.nio.charset.StandardCharsets;
+import java.time.Duration;
+import java.time.Instant;
+import java.time.LocalDate;
+import java.time.LocalDateTime;
+import java.time.ZoneOffset;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+import java.util.UUID;
+import java.util.function.BiConsumer;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+import com.google.common.hash.Hashing;
+import org.junit.Assert;
+
+import tech.ydb.table.description.TableColumn;
+import tech.ydb.table.description.TableDescription;
+import tech.ydb.table.query.ApacheArrowWriter;
+import tech.ydb.table.result.ResultSetReader;
+import tech.ydb.table.result.ValueReader;
+import tech.ydb.table.values.DecimalType;
+import tech.ydb.table.values.DecimalValue;
+import tech.ydb.table.values.ListType;
+import tech.ydb.table.values.ListValue;
+import tech.ydb.table.values.PrimitiveType;
+import tech.ydb.table.values.PrimitiveValue;
+import tech.ydb.table.values.StructType;
+import tech.ydb.table.values.StructValue;
+import tech.ydb.table.values.Type;
+import tech.ydb.table.values.Value;
+
+/**
+ *
+ * @author Aleksandr Gorshenin
+ */
+public class AllTypesRecord {
+ private static final DecimalType YDB_DECIMAL = DecimalType.getDefault();
+ private static final DecimalType BANK_DECIMAL = DecimalType.of(31, 9);
+ private static final DecimalType BIG_DECIMAL = DecimalType.of(35, 0);
+
+ // Keys
+ private final long id1;
+ private final int id2;
+ private final byte[] payload;
+ private final int length;
+ private final String hash;
+
+ // All types
+ private Byte v_int8;
+ private Short v_int16;
+ private Integer v_int32;
+ private Long v_int64;
+
+ private Integer v_uint8;
+ private Integer v_uint16;
+ private Long v_uint32;
+ private Long v_uint64;
+
+ private Boolean v_bool;
+ private Float v_float;
+ private Double v_double;
+
+ private String v_text;
+ private String v_json;
+ private String v_jsdoc;
+
+ private byte[] v_bytes;
+ private byte[] v_yson;
+
+ private UUID v_uuid;
+
+ private LocalDate v_date;
+ private LocalDateTime v_datetime;
+ private Instant v_timestamp;
+ private Duration v_interval;
+
+ private LocalDate v_date32;
+ private LocalDateTime v_datetime64;
+ private Instant v_timestamp64;
+ private Duration v_interval64;
+
+ private DecimalValue v_ydb_decimal;
+ private DecimalValue v_bank_decimal;
+ private DecimalValue v_big_decimal;
+
+ private AllTypesRecord(long id1, int id2, byte[] payload) {
+ // Keys
+ this.id1 = id1;
+ this.id2 = id2;
+ this.length = payload.length;
+ this.payload = payload;
+ this.hash = Hashing.sha256().hashBytes(payload).toString();
+ }
+
+ private <T> void assertValue(Set<String> columns, String key, int idx, Function<ValueReader, T> reader,
+ ResultSetReader rs, T value) {
+ if (!columns.contains(key)) {
+ return;
+ }
+
+ String msg = "Row " + idx + " " + key;
+ Assert.assertEquals(msg, value != null, rs.getColumn(key).isOptionalItemPresent());
+
+ if (value != null) {
+ if (value instanceof byte[]) {
+ Assert.assertArrayEquals(msg, (byte[]) value, (byte[]) reader.apply(rs.getColumn(key)));
+ } else {
+ Assert.assertEquals(msg, value, reader.apply(rs.getColumn(key)));
+ }
+ }
+ }
+
+ public void assertRow(Set<String> columns, int idx, ResultSetReader rs) {
+ // keys are required
+ Assert.assertEquals("Row " + idx + " id1", id1, rs.getColumn("id1").getUint64());
+ Assert.assertEquals("Row " + idx + " id2", id2, rs.getColumn("id2").getInt64());
+ Assert.assertArrayEquals("Row " + idx + " payload", payload, rs.getColumn("payload").getBytes());
+ Assert.assertEquals("Row " + idx + " length", length, rs.getColumn("length").getUint32());
+ Assert.assertEquals("Row " + idx + " hash", hash, rs.getColumn("hash").getText());
+
+ // other columns may be skipped or be empty
+ assertValue(columns, "Int8", idx, ValueReader::getInt8, rs, v_int8);
+ assertValue(columns, "Int16", idx, ValueReader::getInt16, rs, v_int16);
+ assertValue(columns, "Int32", idx, ValueReader::getInt32, rs, v_int32);
+ assertValue(columns, "Int64", idx, ValueReader::getInt64, rs, v_int64);
+
+ assertValue(columns, "Uint8", idx, ValueReader::getUint8, rs, v_uint8);
+ assertValue(columns, "Uint16", idx, ValueReader::getUint16, rs, v_uint16);
+ assertValue(columns, "Uint32", idx, ValueReader::getUint32, rs, v_uint32);
+ assertValue(columns, "Uint64", idx, ValueReader::getUint64, rs, v_uint64);
+
+ assertValue(columns, "Bool", idx, ValueReader::getBool, rs, v_bool);
+ assertValue(columns, "Float", idx, ValueReader::getFloat, rs, v_float);
+ assertValue(columns, "Double", idx, ValueReader::getDouble, rs, v_double);
+
+ assertValue(columns, "Text", idx, ValueReader::getText, rs, v_text);
+ assertValue(columns, "Json", idx, ValueReader::getJson, rs, v_json);
+ assertValue(columns, "JsonDocument", idx, ValueReader::getJsonDocument, rs, v_jsdoc);
+
+ assertValue(columns, "Uuid", idx, ValueReader::getUuid, rs, v_uuid);
+
+ assertValue(columns, "Bytes", idx, ValueReader::getBytes, rs, v_bytes);
+ assertValue(columns, "Yson", idx, ValueReader::getYson, rs, v_yson);
+
+ assertValue(columns, "Date", idx, ValueReader::getDate, rs, v_date);
+ assertValue(columns, "Datetime", idx, ValueReader::getDatetime, rs, v_datetime);
+ assertValue(columns, "Timestamp", idx, ValueReader::getTimestamp, rs, v_timestamp);
+ assertValue(columns, "Interval", idx, ValueReader::getInterval, rs, v_interval);
+
+ assertValue(columns, "Date32", idx, ValueReader::getDate32, rs, v_date32);
+ assertValue(columns, "Datetime64", idx, ValueReader::getDatetime64, rs, v_datetime64);
+ assertValue(columns, "Timestamp64", idx, ValueReader::getTimestamp64, rs, v_timestamp64);
+ assertValue(columns, "Interval64", idx, ValueReader::getInterval64, rs, v_interval64);
+
+ assertValue(columns, "YdbDecimal", idx, ValueReader::getDecimal, rs, v_ydb_decimal);
+ assertValue(columns, "BankDecimal", idx, ValueReader::getDecimal, rs, v_bank_decimal);
+ assertValue(columns, "BigDecimal", idx, ValueReader::getDecimal, rs, v_big_decimal);
+ }
+
+ private <T> Value<?> makePb(Type type, Function<T, Value<?>> func, T value) {
+ return value != null ? func.apply(value) : type.makeOptional().emptyValue();
+ }
+
+ private Map<String, Value<?>> toPb(List<String> columnOnly) {
+ Map<String, Value<?>> struct = new HashMap<>();
+
+ BiConsumer<String, Value<?>> write = (key, value) -> {
+ if (columnOnly.contains(key)) {
+ struct.put(key, value);
+ }
+ };
+
+ // keys are required
+ struct.put("id1", PrimitiveValue.newUint64(id1));
+ struct.put("id2", PrimitiveValue.newInt64(id2));
+ struct.put("payload", PrimitiveValue.newBytesOwn(payload));
+ struct.put("length", PrimitiveValue.newUint32(length));
+ struct.put("hash", PrimitiveValue.newText(hash));
+
+ // other columns may be skipped or be empty
+ write.accept("Int8", makePb(PrimitiveType.Int8, PrimitiveValue::newInt8, v_int8));
+ write.accept("Int16", makePb(PrimitiveType.Int16, PrimitiveValue::newInt16, v_int16));
+ write.accept("Int32", makePb(PrimitiveType.Int32, PrimitiveValue::newInt32, v_int32));
+ write.accept("Int64", makePb(PrimitiveType.Int64, PrimitiveValue::newInt64, v_int64));
+
+ write.accept("Uint8", makePb(PrimitiveType.Uint8, PrimitiveValue::newUint8, v_uint8));
+ write.accept("Uint16", makePb(PrimitiveType.Uint16, PrimitiveValue::newUint16, v_uint16));
+ write.accept("Uint32", makePb(PrimitiveType.Uint32, PrimitiveValue::newUint32, v_uint32));
+ write.accept("Uint64", makePb(PrimitiveType.Uint64, PrimitiveValue::newUint64, v_uint64));
+
+ write.accept("Bool", makePb(PrimitiveType.Bool, PrimitiveValue::newBool, v_bool));
+ write.accept("Float", makePb(PrimitiveType.Float, PrimitiveValue::newFloat, v_float));
+ write.accept("Double", makePb(PrimitiveType.Double, PrimitiveValue::newDouble, v_double));
+
+ write.accept("Text", makePb(PrimitiveType.Text, PrimitiveValue::newText, v_text));
+ write.accept("Json", makePb(PrimitiveType.Json, PrimitiveValue::newJson, v_json));
+ write.accept("JsonDocument", makePb(PrimitiveType.JsonDocument, PrimitiveValue::newJsonDocument, v_jsdoc));
+
+ write.accept("Bytes", makePb(PrimitiveType.Bytes, PrimitiveValue::newBytes, v_bytes));
+ write.accept("Yson", makePb(PrimitiveType.Yson, PrimitiveValue::newYson, v_yson));
+
+ write.accept("Uuid", makePb(PrimitiveType.Uuid, PrimitiveValue::newUuid, v_uuid));
+
+ write.accept("Date", makePb(PrimitiveType.Date, PrimitiveValue::newDate, v_date));
+ write.accept("Datetime", makePb(PrimitiveType.Datetime, PrimitiveValue::newDatetime, v_datetime));
+ write.accept("Timestamp", makePb(PrimitiveType.Timestamp, PrimitiveValue::newTimestamp, v_timestamp));
+ write.accept("Interval", makePb(PrimitiveType.Interval, PrimitiveValue::newInterval, v_interval));
+
+ write.accept("Date32", makePb(PrimitiveType.Date32, PrimitiveValue::newDate32, v_date32));
+ write.accept("Datetime64", makePb(PrimitiveType.Datetime64, PrimitiveValue::newDatetime64, v_datetime64));
+ write.accept("Timestamp64", makePb(PrimitiveType.Timestamp64, PrimitiveValue::newTimestamp64, v_timestamp64));
+ write.accept("Interval64", makePb(PrimitiveType.Interval64, PrimitiveValue::newInterval64, v_interval64));
+
+ write.accept("YdbDecimal", makePb(YDB_DECIMAL, Function.identity(), v_ydb_decimal));
+ write.accept("BankDecimal", makePb(BANK_DECIMAL, Function.identity(), v_bank_decimal));
+ write.accept("BigDecimal", makePb(BIG_DECIMAL, Function.identity(), v_big_decimal));
+
+ return struct;
+ }
+
+ private <T> void writeNullable(Set<String> columns, String key, ApacheArrowWriter.Row row,
+ BiConsumer<String, T> writer, T value) {
+ if (!columns.contains(key)) {
+ return;
+ }
+
+ if (value == null) {
+ row.writeNull(key);
+ } else {
+ writer.accept(key, value);
+ }
+ }
+
+ public void writeToApacheArrow(Set<String> columnNames, ApacheArrowWriter.Row row) {
+ // keys are required
+ row.writeUint64("id1", id1);
+ row.writeInt64("id2", id2);
+ row.writeBytes("payload", payload);
+ row.writeUint32("length", length);
+ row.writeText("hash", hash);
+
+ // other columns may be skipped or be empty
+ writeNullable(columnNames, "Int8", row, row::writeInt8, v_int8);
+ writeNullable(columnNames, "Int16", row, row::writeInt16, v_int16);
+ writeNullable(columnNames, "Int32", row, row::writeInt32, v_int32);
+ writeNullable(columnNames, "Int64", row, row::writeInt64, v_int64);
+
+ writeNullable(columnNames, "Uint8", row, row::writeUint8, v_uint8);
+ writeNullable(columnNames, "Uint16", row, row::writeUint16, v_uint16);
+ writeNullable(columnNames, "Uint32", row, row::writeUint32, v_uint32);
+ writeNullable(columnNames, "Uint64", row, row::writeUint64, v_uint64);
+
+ writeNullable(columnNames, "Bool", row, row::writeBool, v_bool);
+
+ writeNullable(columnNames, "Float", row, row::writeFloat, v_float);
+ writeNullable(columnNames, "Double", row, row::writeDouble, v_double);
+
+ writeNullable(columnNames, "Text", row, row::writeText, v_text);
+ writeNullable(columnNames, "Json", row, row::writeJson, v_json);
+ writeNullable(columnNames, "JsonDocument", row, row::writeJsonDocument, v_jsdoc);
+
+ writeNullable(columnNames, "Bytes", row, row::writeBytes, v_bytes);
+ writeNullable(columnNames, "Yson", row, row::writeYson, v_yson);
+
+ writeNullable(columnNames, "Uuid", row, row::writeUuid, v_uuid);
+
+ writeNullable(columnNames, "Date", row, row::writeDate, v_date);
+ writeNullable(columnNames, "Datetime", row, row::writeDatetime, v_datetime);
+ writeNullable(columnNames, "Timestamp", row, row::writeTimestamp, v_timestamp);
+ writeNullable(columnNames, "Interval", row, row::writeInterval, v_interval);
+
+ writeNullable(columnNames, "Date32", row, row::writeDate32, v_date32);
+ writeNullable(columnNames, "Datetime64", row, row::writeDatetime64, v_datetime64);
+ writeNullable(columnNames, "Timestamp64", row, row::writeTimestamp64, v_timestamp64);
+ writeNullable(columnNames, "Interval64", row, row::writeInterval64, v_interval64);
+
+ writeNullable(columnNames, "YdbDecimal", row, row::writeDecimal, v_ydb_decimal);
+ writeNullable(columnNames, "BankDecimal", row, row::writeDecimal, v_bank_decimal);
+ writeNullable(columnNames, "BigDecimal", row, row::writeDecimal, v_big_decimal);
+ }
+
+ public static DecimalValue randomDecimal(DecimalType type, Random rnd) {
+ int kind = rnd.nextInt(1000);
+ switch (kind) {
+ case 0: return type.getNegInf();
+ case 499: return type.newValue(0);
+ case 999: return type.getInf();
+ default:
+ break;
+ }
+ BigInteger unscaled = new BigInteger(1 + rnd.nextInt(type.getPrecision() * 3 - 1), rnd);
+ if (kind < 500) {
+ unscaled = unscaled.negate();
+ }
+ return type.newValueUnscaled(unscaled);
+ }
+
+ private static <T> T nullable(Random rnd, T value) {
+ return rnd.nextInt(10) == 0 ? null : value;
+ }
+
+ public static AllTypesRecord random(long id1, int id2, Random rnd) {
+ int length = 10 + rnd.nextInt(256);
+ byte[] payload = new byte[length];
+ rnd.nextBytes(payload);
+
+ AllTypesRecord r = new AllTypesRecord(id1, id2, payload);
+
+ // All types
+ r.v_int8 = nullable(rnd, (byte) rnd.nextInt());
+ r.v_int16 = nullable(rnd, (short) rnd.nextInt());
+ r.v_int32 = nullable(rnd, rnd.nextInt());
+ r.v_int64 = nullable(rnd, rnd.nextLong());
+
+ r.v_uint8 = nullable(rnd, rnd.nextInt() & 0xFF);
+ r.v_uint16 = nullable(rnd, rnd.nextInt() & 0xFFFF);
+ r.v_uint32 = nullable(rnd, rnd.nextLong() & 0xFFFFFFFFL);
+ r.v_uint64 = nullable(rnd, rnd.nextLong());
+
+ r.v_bool = nullable(rnd, rnd.nextBoolean());
+ r.v_float = nullable(rnd, rnd.nextFloat());
+ r.v_double = nullable(rnd, rnd.nextDouble());
+
+ r.v_text = nullable(rnd, "Text" + rnd.nextInt(1000));
+ r.v_json = nullable(rnd, "{\"json\":" + (1000 + rnd.nextInt(1000)) + "}");
+ r.v_jsdoc = nullable(rnd, "{\"document\":" + (2000 + rnd.nextInt(1000)) + "}");
+
+ r.v_bytes = nullable(rnd, ("Bytes " + (3000 + rnd.nextInt(1000))).getBytes(StandardCharsets.UTF_8));
+ r.v_yson = nullable(rnd, ("{yson=" + (4000 + rnd.nextInt(1000)) + "}").getBytes(StandardCharsets.UTF_8));
+
+ r.v_uuid = nullable(rnd, UUID.nameUUIDFromBytes(("UUID" + rnd.nextInt()).getBytes(StandardCharsets.UTF_8)));
+
+ r.v_date = nullable(rnd, LocalDate.ofEpochDay(rnd.nextInt(5000)));
+ r.v_datetime = nullable(rnd, LocalDateTime.ofEpochSecond(0x60FFFFFF & rnd.nextLong(), 0, ZoneOffset.UTC));
+ r.v_timestamp = nullable(rnd, Instant.ofEpochSecond(0x3FFFFFFFL & rnd.nextLong(),
+ rnd.nextInt(1000000) * 1000));
+ r.v_interval = nullable(rnd, Duration.ofNanos((0x7FFFFFFFFFL & rnd.nextLong() - 0x4000000000L) * 1000));
+
+ r.v_date32 = nullable(rnd, LocalDate.ofEpochDay(rnd.nextInt(5000) - 2500));
+ r.v_datetime64 = nullable(rnd, LocalDateTime.ofEpochSecond(0x60FFFFFF & rnd.nextLong() - 0x30FFFFFF, 0,
+ ZoneOffset.UTC));
+ r.v_timestamp64 = nullable(rnd, Instant.ofEpochSecond(0x7FFFFFFFFFL & rnd.nextLong() - 0x4000000000L,
+ rnd.nextInt(1000000) * 1000));
+ r.v_interval64 = nullable(rnd, Duration.ofNanos((0x7FFFFFFFFFL & rnd.nextLong() - 0x4000000000L) * 1000));
+
+ r.v_ydb_decimal = nullable(rnd, randomDecimal(YDB_DECIMAL, rnd));
+ r.v_bank_decimal = nullable(rnd, randomDecimal(BANK_DECIMAL, rnd));
+ r.v_big_decimal = nullable(rnd, randomDecimal(BIG_DECIMAL, rnd));
+
+ return r;
+ }
+
+ public static List<AllTypesRecord> randomBatch(int id1, int id2_start, int count) {
+ Random rnd = new Random(id1 * count + id2_start);
+ List<AllTypesRecord> batch = new ArrayList<>(count);
+ for (int idx = 0; idx < count; idx += 1) {
+ batch.add(random(id1, id2_start + idx, rnd));
+ }
+ return batch;
+ }
+
+ public static ListValue createProtobufBatch(TableDescription desc, List<AllTypesRecord> batch) {
+ List<String> columnNames = desc.getColumns().stream().map(TableColumn::getName).collect(Collectors.toList());
+ List<Type> columnTypes = desc.getColumns().stream().map(TableColumn::getType).collect(Collectors.toList());
+
+ StructType type = StructType.of(columnNames, columnTypes);
+ List<Value<?>> values = batch.stream().map(r -> type.newValue(r.toPb(columnNames)))
+ .collect(Collectors.toList());
+
+ return ListType.of(type).newValue(values);
+ }
+
+ public static TableDescription createTableDescription(boolean isColumnShard) {
+ TableDescription.Builder builder = TableDescription.newBuilder()
+ .addNonnullColumn("id1", PrimitiveType.Uint64)
+ .addNonnullColumn("id2", PrimitiveType.Int64)
+ .addNonnullColumn("payload", PrimitiveType.Bytes)
+ .addNonnullColumn("length", PrimitiveType.Uint32)
+ .addNonnullColumn("hash", PrimitiveType.Text)
+
+ .addNullableColumn("Int8", PrimitiveType.Int8)
+ .addNullableColumn("Int16", PrimitiveType.Int16)
+ .addNullableColumn("Int32", PrimitiveType.Int32)
+ .addNullableColumn("Int64", PrimitiveType.Int64)
+ .addNullableColumn("Uint8", PrimitiveType.Uint8)
+ .addNullableColumn("Uint16", PrimitiveType.Uint16)
+ .addNullableColumn("Uint32", PrimitiveType.Uint32)
+ .addNullableColumn("Uint64", PrimitiveType.Uint64)
+ .addNullableColumn("Bool", PrimitiveType.Bool)
+ .addNullableColumn("Float", PrimitiveType.Float)
+ .addNullableColumn("Double", PrimitiveType.Double)
+ .addNullableColumn("Text", PrimitiveType.Text)
+ .addNullableColumn("Json", PrimitiveType.Json)
+ .addNullableColumn("JsonDocument", PrimitiveType.JsonDocument)
+ .addNullableColumn("Bytes", PrimitiveType.Bytes)
+ .addNullableColumn("Yson", PrimitiveType.Yson);
+
+ // https://github.com/ydb-platform/ydb/issues/13047
+ if (!isColumnShard) {
+ builder = builder.addNullableColumn("Uuid", PrimitiveType.Uuid);
+ }
+
+ builder = builder
+ .addNullableColumn("Date", PrimitiveType.Date)
+ .addNullableColumn("Datetime", PrimitiveType.Datetime)
+ .addNullableColumn("Timestamp", PrimitiveType.Timestamp);
+
+ // https://github.com/ydb-platform/ydb/issues/13050
+ if (!isColumnShard) {
+ builder = builder.addNullableColumn("Interval", PrimitiveType.Interval);
+ }
+
+ builder = builder
+ .addNullableColumn("Date32", PrimitiveType.Date32)
+ .addNullableColumn("Datetime64", PrimitiveType.Datetime64)
+ .addNullableColumn("Timestamp64", PrimitiveType.Timestamp64)
+ .addNullableColumn("Interval64", PrimitiveType.Interval64)
+ .addNullableColumn("YdbDecimal", YDB_DECIMAL)
+ .addNullableColumn("BankDecimal", BANK_DECIMAL)
+ .addNullableColumn("BigDecimal", BIG_DECIMAL);
+
+ if (isColumnShard) {
+ builder = builder.setPrimaryKey("hash").setStoreType(TableDescription.StoreType.COLUMN);
+ } else {
+ builder = builder.setPrimaryKeys("id1", "id2").setStoreType(TableDescription.StoreType.ROW);
+ }
+
+ return builder.build();
+ }
+}
diff --git a/table/src/test/java/tech/ydb/table/integration/BulkUpsertTest.java b/table/src/test/java/tech/ydb/table/integration/BulkUpsertTest.java
new file mode 100644
index 000000000..b67d990aa
--- /dev/null
+++ b/table/src/test/java/tech/ydb/table/integration/BulkUpsertTest.java
@@ -0,0 +1,254 @@
+package tech.ydb.table.integration;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CompletionException;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
+import java.util.stream.Collectors;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.ClassRule;
+import org.junit.Test;
+
+import tech.ydb.core.grpc.GrpcReadStream;
+import tech.ydb.table.SessionRetryContext;
+import tech.ydb.table.description.TableColumn;
+import tech.ydb.table.description.TableDescription;
+import tech.ydb.table.impl.SimpleTableClient;
+import tech.ydb.table.query.ApacheArrowWriter;
+import tech.ydb.table.query.BulkUpsertData;
+import tech.ydb.table.query.Params;
+import tech.ydb.table.result.ResultSetReader;
+import tech.ydb.table.rpc.grpc.GrpcTableRpc;
+import tech.ydb.table.settings.ExecuteScanQuerySettings;
+import tech.ydb.table.settings.ExecuteSchemeQuerySettings;
+import tech.ydb.table.values.ListValue;
+import tech.ydb.table.values.PrimitiveValue;
+import tech.ydb.test.junit4.GrpcTransportRule;
+
+/**
+ *
+ * @author Aleksandr Gorshenin
+ */
+public class BulkUpsertTest {
+ @ClassRule
+ public static final GrpcTransportRule YDB = new GrpcTransportRule();
+ private static final SimpleTableClient client = SimpleTableClient.newClient(GrpcTableRpc.useTransport(YDB)).build();
+ private static final SessionRetryContext retryCtx = SessionRetryContext.create(client)
+ .idempotent(false)
+ .build();
+
+ private static final String TEST_TABLE = "arrow/test_table";
+ private static final String DROP_TABLE_YQL = "DROP TABLE IF EXISTS `" + TEST_TABLE + "`;";
+ private static final String SELECT_TABLE_YQL = "DECLARE $id1 AS Uint64; SELECT * FROM `" + TEST_TABLE + "` "
+ + "WHERE id1 = $id1 ORDER BY id2";
+
+ private static String tablePath() {
+ return YDB.getDatabase() + "/" + TEST_TABLE;
+ }
+
+ @After
+ public void cleanTable() {
+ retryCtx.supplyStatus(session -> session.executeSchemeQuery(DROP_TABLE_YQL, new ExecuteSchemeQuerySettings()))
+ .join().expectSuccess("cannot drop table");
+ }
+
+ private static void createTable(TableDescription table) {
+ retryCtx.supplyStatus(s -> s.createTable(tablePath(), table))
+ .join().expectSuccess("Cannot create table");
+ }
+
+ private static void bulkUpsert(BulkUpsertData data) {
+ retryCtx.supplyStatus(session -> session.executeBulkUpsert(tablePath(), data))
+ .join().expectSuccess("bulk upsert problem in table " + tablePath());
+ }
+
+ private static void bulkUpsert(ListValue rows) {
+ retryCtx.supplyStatus(session -> session.executeBulkUpsert(tablePath(), rows))
+ .join().expectSuccess("bulk upsert problem in table " + tablePath());
+ }
+
+ private static int readTable(long id1, BiConsumer<Integer, ResultSetReader> validator) {
+ AtomicInteger count = new AtomicInteger();
+
+ try {
+ retryCtx.supplyStatus(session -> {
+ count.set(0);
+
+ GrpcReadStream<ResultSetReader> stream = session.executeScanQuery(SELECT_TABLE_YQL,
+ Params.of("$id1", PrimitiveValue.newUint64(id1)),
+ ExecuteScanQuerySettings.newBuilder().build()
+ );
+
+ return stream.start((rs) -> {
+ while (rs.next()) {
+ int idx = count.getAndIncrement();
+ validator.accept(idx, rs);
+ }
+ });
+ }).join().expectSuccess("Cannot read table " + TEST_TABLE);
+ } catch (CompletionException ex) {
+ if (ex.getCause() instanceof AssertionError) {
+ throw (AssertionError) ex.getCause();
+ }
+ throw ex;
+ }
+
+ return count.get();
+ }
+
+ @Test
+ public void writeProtobufToDataShardTest() {
+ // Create table
+ TableDescription table = AllTypesRecord.createTableDescription(false);
+ createTable(table);
+
+ Set<String> columnNames = table.getColumns().stream().map(TableColumn::getName).collect(Collectors.toSet());
+
+ // Write & read batch of 1000 records with id1 = 1
+ List<AllTypesRecord> batch1 = AllTypesRecord.randomBatch(1, 1, 1000);
+ bulkUpsert(AllTypesRecord.createProtobufBatch(table, batch1));
+
+ int rows1Count = readTable(1, (idx, rs) -> {
+ Assert.assertTrue("Unexpected row index", idx < batch1.size());
+ batch1.get(idx).assertRow(columnNames, idx, rs);
+ });
+ Assert.assertEquals(1000, rows1Count);
+
+ // Write & read batch of 2000 records with id1 = 2
+ List<AllTypesRecord> batch2 = AllTypesRecord.randomBatch(2, 1, 2000);
+ bulkUpsert(AllTypesRecord.createProtobufBatch(table, batch2));
+
+ int rows2Count = readTable(2, (idx, rs) -> {
+ Assert.assertTrue("Unexpected row index", idx < batch2.size());
+ batch2.get(idx).assertRow(columnNames, idx, rs);
+ });
+ Assert.assertEquals(2000, rows2Count);
+ }
+
+ @Test
+ public void writeProtobufToColumnShardTable() {
+ // Create table
+ TableDescription table = AllTypesRecord.createTableDescription(true);
+ Set<String> columnNames = table.getColumns().stream().map(TableColumn::getName).collect(Collectors.toSet());
+
+ createTable(table);
+
+ // Write & read batch of 2500 records with id1 = 1
+ List<AllTypesRecord> batch1 = AllTypesRecord.randomBatch(1, 1, 2500);
+ bulkUpsert(AllTypesRecord.createProtobufBatch(table, batch1));
+ // Read table and validate data
+ int rows1Count = readTable(1, (idx, rs) -> {
+ Assert.assertTrue("Unexpected row index", idx < batch1.size());
+ batch1.get(idx).assertRow(columnNames, idx, rs);
+ });
+ Assert.assertEquals(2500, rows1Count);
+
+ // Write & read batch of 5000 records with id1 = 2
+ List<AllTypesRecord> batch2 = AllTypesRecord.randomBatch(2, 1, 5000);
+ bulkUpsert(AllTypesRecord.createProtobufBatch(table, batch2));
+
+ int rows2Count = readTable(2, (idx, rs) -> {
+ Assert.assertTrue("Unexpected row index", idx < batch2.size());
+ batch2.get(idx).assertRow(columnNames, idx, rs);
+ });
+ Assert.assertEquals(5000, rows2Count);
+ }
+
+ @Test
+ public void writeApacheArrowToDataShardTest() {
+ // Create table
+ TableDescription table = AllTypesRecord.createTableDescription(false);
+ retryCtx.supplyStatus(s -> s.createTable(tablePath(), table)).join().expectSuccess("Cannot create table");
+
+ Set<String> columnNames = table.getColumns().stream().map(TableColumn::getName).collect(Collectors.toSet());
+
+ ApacheArrowWriter.Schema schema = ApacheArrowWriter.newSchema();
+ table.getColumns().forEach(column -> schema.addColumn(column.getName(), column.getType()));
+
+ // Create batch of 1000 records
+ List<AllTypesRecord> batch1 = AllTypesRecord.randomBatch(1, 1, 1000);
+ // Create batch of 2000 records
+ List<AllTypesRecord> batch2 = AllTypesRecord.randomBatch(2, 1, 2000);
+
+ try (BufferAllocator allocator = new RootAllocator()) {
+ try (ApacheArrowWriter writer = schema.createWriter(allocator)) {
+ // create batch with estimated size
+ ApacheArrowWriter.Batch data1 = writer.createNewBatch(1000);
+ batch1.forEach(r -> r.writeToApacheArrow(columnNames, data1.writeNextRow()));
+ bulkUpsert(data1.buildBatch());
+
+ // create batch without estimated size
+ ApacheArrowWriter.Batch data2 = writer.createNewBatch(0);
+ batch2.forEach(r -> r.writeToApacheArrow(columnNames, data2.writeNextRow()));
+ bulkUpsert(data2.buildBatch());
+
+ } catch (IOException ex) {
+ throw new AssertionError("Cannot serialize apache arrow", ex);
+ }
+ }
+
+ int rows1Count = readTable(1, (idx, rs) -> {
+ Assert.assertTrue("Unexpected row index", idx < batch1.size());
+ batch1.get(idx).assertRow(columnNames, idx, rs);
+ });
+ Assert.assertEquals(1000, rows1Count);
+
+ int rows2Count = readTable(2, (idx, rs) -> {
+ Assert.assertTrue("Unexpected row index ", idx < batch2.size());
+ batch2.get(idx).assertRow(columnNames, idx, rs);
+ });
+ Assert.assertEquals(2000, rows2Count);
+ }
+
+ @Test
+ public void writeApacheArrowToColumnShardTest() {
+ // Create table
+ TableDescription table = AllTypesRecord.createTableDescription(true);
+ retryCtx.supplyStatus(s -> s.createTable(tablePath(), table)).join().expectSuccess("Cannot create table");
+
+ Set<String> columnNames = table.getColumns().stream().map(TableColumn::getName).collect(Collectors.toSet());
+
+ ApacheArrowWriter.Schema schema = ApacheArrowWriter.newSchema();
+ table.getColumns().forEach(column -> schema.addColumn(column.getName(), column.getType()));
+
+ // Create batch of 2500 records
+ List<AllTypesRecord> batch1 = AllTypesRecord.randomBatch(1, 1, 2500);
+ // Create batch of 5000 records
+ List<AllTypesRecord> batch2 = AllTypesRecord.randomBatch(2, 1, 5000);
+
+ try (BufferAllocator allocator = new RootAllocator()) {
+ try (ApacheArrowWriter writer = schema.createWriter(allocator)) {
+ // create batch without estimated size
+ ApacheArrowWriter.Batch data1 = writer.createNewBatch(0);
+ batch1.forEach(r -> r.writeToApacheArrow(columnNames, data1.writeNextRow()));
+ bulkUpsert(data1.buildBatch());
+
+ // create batch with estimated size
+ ApacheArrowWriter.Batch data2 = writer.createNewBatch(5000);
+ batch2.forEach(r -> r.writeToApacheArrow(columnNames, data2.writeNextRow()));
+ bulkUpsert(data2.buildBatch());
+
+ } catch (IOException ex) {
+ throw new AssertionError("Cannot serialize apache arrow", ex);
+ }
+ }
+
+ int rows1Count = readTable(1, (idx, rs) -> {
+ Assert.assertTrue("Unexpected row index", idx < batch1.size());
+ batch1.get(idx).assertRow(columnNames, idx, rs);
+ });
+ Assert.assertEquals(2500, rows1Count);
+
+ int rows2Count = readTable(2, (idx, rs) -> {
+ Assert.assertTrue("Unexpected row index ", idx < batch2.size());
+ batch2.get(idx).assertRow(columnNames, idx, rs);
+ });
+ Assert.assertEquals(5000, rows2Count);
+ }
+}
diff --git a/table/src/test/java/tech/ydb/table/integration/ReadTableTest.java b/table/src/test/java/tech/ydb/table/integration/ReadTableTest.java
index 1d736c3af..d86996b49 100644
--- a/table/src/test/java/tech/ydb/table/integration/ReadTableTest.java
+++ b/table/src/test/java/tech/ydb/table/integration/ReadTableTest.java
@@ -99,7 +99,7 @@ public static void prepareTable() {
}).collect(Collectors.toList());
retryCtx.supplyStatus(session -> session.executeBulkUpsert(tablePath,
- new BulkUpsertData(ProtoValue.toTypedValue(ListType.of(batchType).newValue(batchData)))
+ BulkUpsertData.fromProto(ProtoValue.toTypedValue(ListType.of(batchType).newValue(batchData)))
)).join().expectSuccess("bulk upsert problem in table " + tablePath);
}
diff --git a/table/src/test/java/tech/ydb/table/integration/ValuesReadTest.java b/table/src/test/java/tech/ydb/table/integration/ValuesReadTest.java
index ba91f67f4..89d1faf07 100644
--- a/table/src/test/java/tech/ydb/table/integration/ValuesReadTest.java
+++ b/table/src/test/java/tech/ydb/table/integration/ValuesReadTest.java
@@ -468,7 +468,6 @@ public void date32datetime64timestamp64interval64() {
@Test
public void timestamp64ReadTest() {
- System.out.println(Instant.ofEpochSecond(-4611669897600L));
DataQueryResult result = CTX.supplyResult(
s -> s.executeDataQuery("SELECT "
+ "Timestamp64('-144169-01-01T00:00:00Z') as t1,"
diff --git a/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java b/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java
new file mode 100644
index 000000000..4f48787cb
--- /dev/null
+++ b/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java
@@ -0,0 +1,402 @@
+package tech.ydb.table.query;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.channels.Channels;
+import java.nio.charset.StandardCharsets;
+import java.time.Duration;
+import java.time.Instant;
+import java.time.LocalDate;
+import java.time.LocalDateTime;
+import java.util.UUID;
+
+import com.google.protobuf.ByteString;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.VectorLoader;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.ReadChannel;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.ipc.message.MessageSerializer;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.function.ThrowingRunnable;
+
+import tech.ydb.table.values.DecimalType;
+import tech.ydb.table.values.DecimalValue;
+import tech.ydb.table.values.ListType;
+import tech.ydb.table.values.PrimitiveType;
+
+/**
+ * Unit tests for {@link ApacheArrowWriter}: schema/type validation and Arrow IPC batch serialization.
+ * @author Aleksandr Gorshenin
+ */
+public class ApacheArrowWriterTest {
+
+ private void assertIllegalArgument(String message, ThrowingRunnable runnable) {
+ IllegalArgumentException ex = Assert.assertThrows(IllegalArgumentException.class, runnable);
+ Assert.assertEquals(message, ex.getMessage());
+ }
+
+ private void assertIllegalState(String message, ThrowingRunnable runnable) {
+ IllegalStateException ex = Assert.assertThrows(IllegalStateException.class, runnable);
+ Assert.assertEquals(message, ex.getMessage());
+ }
+
+ private static RootAllocator allocator;
+
+ @BeforeClass
+ public static void initAllocator() {
+ allocator = new RootAllocator();
+ }
+
+ @AfterClass
+ public static void cleanAllocator() {
+ allocator.close();
+ }
+
+ @Test
+ public void unsupportedTypesTest() {
+ // tz date is not supported
+ assertIllegalArgument("Type TzDate is not supported in ArrowWriter", () -> ApacheArrowWriter.newSchema()
+ .addColumn("col", PrimitiveType.TzDate)
+ .createWriter(allocator));
+
+ // complex types are not supported (except Optional)
+ assertIllegalArgument("Type Int32? is not supported in ArrowWriter", () -> ApacheArrowWriter.newSchema()
+ .addNullableColumn("col", PrimitiveType.Int32.makeOptional())
+ .createWriter(allocator));
+
+ // Non primitive types are not supported (except Decimal)
+ assertIllegalArgument("Type List is not supported in ArrowWriter", () -> ApacheArrowWriter.newSchema()
+ .addColumn("col", ListType.of(PrimitiveType.Int32))
+ .createWriter(allocator));
+ }
+
+ @Test
+ public void invalidColumnTest() {
+ try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
+ .addColumn("col", PrimitiveType.Uuid)
+ .createWriter(allocator)) {
+
+ ApacheArrowWriter.Batch batch = writer.createNewBatch(0);
+ ApacheArrowWriter.Row row = batch.writeNextRow();
+ assertIllegalArgument("Column 'col2' not found", () -> row.writeUuid("col2", UUID.randomUUID()));
+ }
+ }
+
+ @Test
+ public void nullableTypeTest() {
+ try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
+ .addColumn("col1", PrimitiveType.Uuid)
+ .addColumn("col2", PrimitiveType.Uuid.makeOptional())
+ .addNullableColumn("col3", PrimitiveType.Uuid)
+ .createWriter(allocator)) {
+
+ ApacheArrowWriter.Batch batch = writer.createNewBatch(0);
+ ApacheArrowWriter.Row row = batch.writeNextRow();
+ assertIllegalState("cannot call writeNull, actual type: Uuid", () -> row.writeNull("col1"));
+ row.writeNull("col2"); // success
+ row.writeNull("col3"); // success
+ }
+ }
+
+ @Test
+ public void baseTypeValidationTest() throws IOException {
+ try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
+ .addColumn("col1", PrimitiveType.Uuid)
+ .addColumn("col2", PrimitiveType.Int32)
+ .createWriter(allocator)) {
+
+ ApacheArrowWriter.Batch batch = writer.createNewBatch(0);
+ ApacheArrowWriter.Row row = batch.writeNextRow();
+
+ assertIllegalState("cannot call writeBool, actual type: Uuid", () -> row.writeBool("col1", true));
+
+ assertIllegalState("cannot call writeInt8, actual type: Uuid", () -> row.writeInt8("col1", (byte) 0));
+ assertIllegalState("cannot call writeInt16, actual type: Uuid", () -> row.writeInt16("col1", (short) 0));
+ assertIllegalState("cannot call writeInt32, actual type: Uuid", () -> row.writeInt32("col1", 0));
+ assertIllegalState("cannot call writeInt64, actual type: Uuid", () -> row.writeInt64("col1", 0));
+
+ assertIllegalState("cannot call writeUint8, actual type: Uuid", () -> row.writeUint8("col1", 0));
+ assertIllegalState("cannot call writeUint16, actual type: Uuid", () -> row.writeUint16("col1", 0));
+ assertIllegalState("cannot call writeUint32, actual type: Uuid", () -> row.writeUint32("col1", 0));
+ assertIllegalState("cannot call writeUint64, actual type: Uuid", () -> row.writeUint64("col1", 0));
+
+ assertIllegalState("cannot call writeFloat, actual type: Uuid", () -> row.writeFloat("col1", 0));
+ assertIllegalState("cannot call writeDouble, actual type: Uuid", () -> row.writeDouble("col1", 0));
+
+ assertIllegalState("cannot call writeText, actual type: Uuid", () -> row.writeText("col1", ""));
+ assertIllegalState("cannot call writeJson, actual type: Uuid", () -> row.writeJson("col1", ""));
+ assertIllegalState("cannot call writeJsonDocument, actual type: Uuid",
+ () -> row.writeJsonDocument("col1", ""));
+
+ assertIllegalState("cannot call writeBytes, actual type: Uuid",
+ () -> row.writeBytes("col1", new byte[0]));
+ assertIllegalState("cannot call writeYson, actual type: Uuid",
+ () -> row.writeYson("col1", new byte[0]));
+
+ assertIllegalState("cannot call writeDate, actual type: Uuid",
+ () -> row.writeDate("col1", LocalDate.ofEpochDay(0)));
+ assertIllegalState("cannot call writeDatetime, actual type: Uuid",
+ () -> row.writeDatetime("col1", LocalDateTime.now()));
+ assertIllegalState("cannot call writeTimestamp, actual type: Uuid",
+ () -> row.writeTimestamp("col1", Instant.ofEpochSecond(0)));
+ assertIllegalState("cannot call writeInterval, actual type: Uuid",
+ () -> row.writeInterval("col1", Duration.ZERO));
+
+ assertIllegalState("cannot call writeDate32, actual type: Uuid",
+ () -> row.writeDate32("col1", LocalDate.ofEpochDay(0)));
+ assertIllegalState("cannot call writeDatetime64, actual type: Uuid",
+ () -> row.writeDatetime64("col1", LocalDateTime.now()));
+ assertIllegalState("cannot call writeTimestamp64, actual type: Uuid",
+ () -> row.writeTimestamp64("col1", Instant.ofEpochSecond(0)));
+ assertIllegalState("cannot call writeInterval64, actual type: Uuid",
+ () -> row.writeInterval64("col1", Duration.ZERO));
+
+ // second column
+ assertIllegalState("cannot call writeDecimal, actual type: Int32",
+ () -> row.writeDecimal("col2", DecimalType.getDefault().newValue(0)));
+ assertIllegalState("cannot call writeUuid, actual type: Int32",
+ () -> row.writeUuid("col2", UUID.randomUUID()));
+ }
+ }
+
+ @Test
+ public void tinyIntVectorTest() throws IOException {
+ try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
+ .addColumn("c1", PrimitiveType.Bool)
+ .addNullableColumn("c2", PrimitiveType.Int8)
+ .addNullableColumn("c3", PrimitiveType.Uint8)
+ .createWriter(allocator)) {
+
+ ApacheArrowWriter.Batch batch = writer.createNewBatch(0);
+ ApacheArrowWriter.Row row = batch.writeNextRow();
+
+ assertIllegalState("cannot call writeInt8, actual type: Bool", () -> row.writeInt8("c1", (byte) 0));
+ assertIllegalState("cannot call writeUint8, actual type: Int8", () -> row.writeUint8("c2", 0));
+ assertIllegalState("cannot call writeBool, actual type: Uint8", () -> row.writeBool("c3", false));
+
+ BulkUpsertArrowData data = batch.buildBatch();
+ // NOTE(review): expected value "Schema" below looks like a truncated placeholder — Arrow's Schema.toString() renders as "Schema<name: type, ...>"; restore the full expected string (same issue in the other *VectorTest methods)
+ Schema schema = readApacheArrowSchema(data.getSchema());
+ Assert.assertEquals("Schema",
+ schema.toString());
+ }
+ }
+
+ @Test
+ public void smallIntVectorTest() throws IOException {
+ try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
+ .addColumn("c1", PrimitiveType.Date)
+ .addNullableColumn("c2", PrimitiveType.Int16)
+ .addNullableColumn("c3", PrimitiveType.Uint16)
+ .createWriter(allocator)) {
+
+ ApacheArrowWriter.Batch batch = writer.createNewBatch(0);
+ ApacheArrowWriter.Row row = batch.writeNextRow();
+
+ assertIllegalState("cannot call writeUint16, actual type: Date", () -> row.writeUint16("c1", 0));
+ assertIllegalState("cannot call writeDate, actual type: Int16",
+ () -> row.writeDate("c2", LocalDate.ofEpochDay(0)));
+ assertIllegalState("cannot call writeInt16, actual type: Uint16", () -> row.writeInt16("c3", (short) 0));
+
+ BulkUpsertArrowData data = batch.buildBatch();
+
+ Schema schema = readApacheArrowSchema(data.getSchema());
+ Assert.assertEquals("Schema",
+ schema.toString());
+ }
+ }
+
+ @Test
+ public void intVectorTest() throws IOException {
+ try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
+ .addColumn("c1", PrimitiveType.Int32)
+ .addColumn("c2", PrimitiveType.Uint32)
+ .addNullableColumn("c3", PrimitiveType.Date32)
+ .addNullableColumn("c4", PrimitiveType.Datetime)
+ .createWriter(allocator)) {
+
+ ApacheArrowWriter.Batch batch = writer.createNewBatch(0);
+ ApacheArrowWriter.Row row = batch.writeNextRow();
+
+ assertIllegalState("cannot call writeUint32, actual type: Int32", () -> row.writeUint32("c1", 0));
+ assertIllegalState("cannot call writeDate32, actual type: Uint32",
+ () -> row.writeDate32("c2", LocalDate.ofEpochDay(0)));
+ assertIllegalState("cannot call writeDatetime, actual type: Date32",
+ () -> row.writeDatetime("c3", LocalDateTime.now()));
+ assertIllegalState("cannot call writeInt32, actual type: Datetime", () -> row.writeInt32("c4", 0));
+
+ BulkUpsertArrowData data = batch.buildBatch();
+
+ Schema schema = readApacheArrowSchema(data.getSchema());
+ Assert.assertEquals("Schema", schema.toString());
+ }
+ }
+
+ @Test
+ public void bigIntVectorTest() throws IOException {
+ try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
+ .addColumn("c1", PrimitiveType.Int64)
+ .addColumn("c2", PrimitiveType.Uint64)
+ .addNullableColumn("c3", PrimitiveType.Datetime64)
+ .addNullableColumn("c4", PrimitiveType.Timestamp)
+ .addNullableColumn("c5", PrimitiveType.Timestamp64)
+ .addNullableColumn("c6", PrimitiveType.Interval)
+ .addNullableColumn("c7", PrimitiveType.Interval64)
+ .createWriter(allocator)) {
+
+ ApacheArrowWriter.Batch batch = writer.createNewBatch(0);
+ ApacheArrowWriter.Row row = batch.writeNextRow();
+
+ assertIllegalState("cannot call writeUint64, actual type: Int64", () -> row.writeUint64("c1", 0));
+ assertIllegalState("cannot call writeDatetime64, actual type: Uint64",
+ () -> row.writeDatetime64("c2", LocalDateTime.now()));
+ assertIllegalState("cannot call writeTimestamp, actual type: Datetime64",
+ () -> row.writeTimestamp("c3", Instant.now()));
+ assertIllegalState("cannot call writeTimestamp64, actual type: Timestamp",
+ () -> row.writeTimestamp64("c4", Instant.now()));
+ assertIllegalState("cannot call writeInterval, actual type: Timestamp64",
+ () -> row.writeInterval("c5", Duration.ZERO));
+ assertIllegalState("cannot call writeInterval64, actual type: Interval",
+ () -> row.writeInterval64("c6", Duration.ZERO));
+ assertIllegalState("cannot call writeInt64, actual type: Interval64",
+ () -> row.writeInt64("c7", 0));
+
+ BulkUpsertArrowData data = batch.buildBatch();
+
+ Schema schema = readApacheArrowSchema(data.getSchema());
+ Assert.assertEquals("Schema",
+ schema.toString());
+ }
+ }
+
+ @Test
+ public void varCharVectorTest() throws IOException {
+ try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
+ .addColumn("c1", PrimitiveType.Text)
+ .addNullableColumn("c2", PrimitiveType.Json)
+ .addColumn("c3", PrimitiveType.JsonDocument)
+ .createWriter(allocator)) {
+
+ ApacheArrowWriter.Batch batch = writer.createNewBatch(0);
+ ApacheArrowWriter.Row row = batch.writeNextRow();
+
+ assertIllegalState("cannot call writeJson, actual type: Text", () -> row.writeJson("c1", ""));
+ assertIllegalState("cannot call writeJsonDocument, actual type: Json", () -> row.writeJsonDocument("c2", ""));
+ assertIllegalState("cannot call writeText, actual type: JsonDocument", () -> row.writeText("c3", ""));
+
+ BulkUpsertArrowData data = batch.buildBatch();
+
+ Schema schema = readApacheArrowSchema(data.getSchema());
+ Assert.assertEquals("Schema", schema.toString());
+ }
+ }
+
+ @Test
+ public void varBinaryVectorTest() throws IOException {
+ try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
+ .addColumn("c1", PrimitiveType.Bytes)
+ .addNullableColumn("c2", PrimitiveType.Yson)
+ .createWriter(allocator)) {
+
+ ApacheArrowWriter.Batch batch = writer.createNewBatch(0);
+ ApacheArrowWriter.Row row = batch.writeNextRow();
+
+ assertIllegalState("cannot call writeYson, actual type: Bytes", () -> row.writeYson("c1", new byte[0]));
+ assertIllegalState("cannot call writeBytes, actual type: Yson", () -> row.writeBytes("c2", new byte[0]));
+
+ BulkUpsertArrowData data = batch.buildBatch();
+
+ Schema schema = readApacheArrowSchema(data.getSchema());
+ Assert.assertEquals("Schema", schema.toString());
+ }
+ }
+
+ @Test
+ public void fixedSizeBinaryVectorTest() throws IOException {
+ DecimalValue dv = DecimalType.getDefault().newValue(0);
+ UUID uv = UUID.randomUUID();
+
+ try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
+ .addColumn("c1", PrimitiveType.Uuid)
+ .addNullableColumn("c2", DecimalType.getDefault())
+ .createWriter(allocator)) {
+
+ ApacheArrowWriter.Batch batch = writer.createNewBatch(0);
+ ApacheArrowWriter.Row row = batch.writeNextRow();
+
+ assertIllegalState("cannot call writeDecimal, actual type: Uuid", () -> row.writeDecimal("c1", dv));
+ assertIllegalState("cannot call writeUuid, actual type: Decimal(22, 9)", () -> row.writeUuid("c2", uv));
+
+ BulkUpsertArrowData data = batch.buildBatch();
+
+ Schema schema = readApacheArrowSchema(data.getSchema());
+ Assert.assertEquals("Schema", schema.toString());
+ }
+ }
+
+ private BulkUpsertArrowData createSimpleBatch() throws IOException {
+ try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
+ .addColumn("pk", PrimitiveType.Int32)
+ .addNullableColumn("value", PrimitiveType.Text)
+ .createWriter(allocator)) {
+
+ ApacheArrowWriter.Batch batch = writer.createNewBatch(2);
+ ApacheArrowWriter.Row row1 = batch.writeNextRow();
+ ApacheArrowWriter.Row row2 = batch.writeNextRow();
+
+ row1.writeInt32("pk", 1);
+ row1.writeText("value", "value-1");
+
+ row2.writeInt32("pk", 2);
+ row2.writeText("value", "значение-2");
+
+ return batch.buildBatch();
+ }
+ }
+
+ @Test
+ public void readArrayBatchTest() throws IOException {
+ BulkUpsertArrowData data = createSimpleBatch();
+
+ Schema schema = readApacheArrowSchema(data.getSchema());
+ try (VectorSchemaRoot vector = VectorSchemaRoot.create(schema, allocator)) {
+ // readApacheArrowBatch
+ try (InputStream is = data.getData().newInput()) {
+ try (ReadChannel channel = new ReadChannel(Channels.newChannel(is))) {
+ try (ArrowRecordBatch batch = MessageSerializer.deserializeRecordBatch(channel, allocator)) {
+ VectorLoader loader = new VectorLoader(vector);
+ loader.load(batch);
+ }
+ }
+ }
+
+ IntVector pk = (IntVector) vector.getVector("pk");
+ VarCharVector value = (VarCharVector) vector.getVector("value");
+
+ Assert.assertEquals(2, vector.getRowCount());
+ Assert.assertEquals(1, pk.get(0));
+ Assert.assertEquals(2, pk.get(1));
+ Assert.assertEquals("value-1", new String(value.get(0), StandardCharsets.UTF_8));
+ Assert.assertEquals("значение-2", new String(value.get(1), StandardCharsets.UTF_8));
+ }
+ }
+
+ private static Schema readApacheArrowSchema(ByteString bytes) {
+ try (InputStream is = bytes.newInput()) {
+ try (ReadChannel channel = new ReadChannel(Channels.newChannel(is))) {
+ return MessageSerializer.deserializeSchema(channel);
+ }
+ } catch (IOException ex) {
+ throw new RuntimeException(ex);
+ }
+ }
+}