diff --git a/pom.xml b/pom.xml index 9207b63dc..f47a2634c 100644 --- a/pom.xml +++ b/pom.xml @@ -40,8 +40,9 @@ 2.25.3 2.13.2 5.21.0 - 2.0.3 + + 17.0.0 @@ -88,6 +89,17 @@ ${gson.version} + + org.apache.arrow + arrow-vector + ${apache.arrow.version} + + + org.apache.arrow + arrow-memory-netty + ${apache.arrow.version} + + junit @@ -238,11 +250,14 @@ jdk8-bootstrap - [9 + [9,) argLine + + + diff --git a/table/pom.xml b/table/pom.xml index 4c5dc1f58..a74e625d1 100644 --- a/table/pom.xml +++ b/table/pom.xml @@ -30,17 +30,26 @@ + org.apache.arrow + arrow-vector + true + + + org.apache.arrow + arrow-memory-netty + true + + + junit junit test - tech.ydb.test ydb-junit4-support test - org.apache.logging.log4j log4j-slf4j2-impl @@ -57,9 +66,23 @@ true ydbplatform/local-ydb:trunk + enable_columnshard_bool + + + + jdk8-bootstrap + + [9,) + + + + --add-opens=java.base/java.nio=ALL-UNNAMED + + + diff --git a/table/src/main/java/tech/ydb/table/Session.java b/table/src/main/java/tech/ydb/table/Session.java index 7793c972e..d4da4e7a1 100644 --- a/table/src/main/java/tech/ydb/table/Session.java +++ b/table/src/main/java/tech/ydb/table/Session.java @@ -199,7 +199,7 @@ default CompletableFuture executeScanQuery(String query, Params params, CompletableFuture> keepAlive(KeepAliveSessionSettings settings); default CompletableFuture executeBulkUpsert(String tablePath, ListValue rows, BulkUpsertSettings settings) { - return executeBulkUpsert(tablePath, new BulkUpsertData(rows), settings); + return executeBulkUpsert(tablePath, BulkUpsertData.fromRows(rows), settings); } CompletableFuture executeBulkUpsert(String tablePath, BulkUpsertData data, BulkUpsertSettings settings); diff --git a/table/src/main/java/tech/ydb/table/query/ApacheArrowWriter.java b/table/src/main/java/tech/ydb/table/query/ApacheArrowWriter.java new file mode 100644 index 000000000..2fa6b4125 --- /dev/null +++ b/table/src/main/java/tech/ydb/table/query/ApacheArrowWriter.java @@ -0,0 +1,868 @@ +package tech.ydb.table.query; 
+ +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.ZoneOffset; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.function.Function; + +import com.google.protobuf.ByteString; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.BaseFixedWidthVector; +import org.apache.arrow.vector.BaseVariableWidthVector; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.FixedSizeBinaryVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; + +import tech.ydb.table.utils.LittleEndian; +import tech.ydb.table.values.DecimalValue; +import tech.ydb.table.values.PrimitiveType; +import tech.ydb.table.values.Type; +import tech.ydb.table.values.proto.ProtoValue; + +/** + * + * @author Aleksandr Gorshenin + */ +public class ApacheArrowWriter implements AutoCloseable { + public interface Batch { + Row writeNextRow(); + 
BulkUpsertArrowData buildBatch() throws IOException; + } + + public interface Row { + void writeNull(String column); + + void writeBool(String column, boolean value); + + void writeInt8(String column, byte value); + void writeInt16(String column, short value); + void writeInt32(String column, int value); + void writeInt64(String column, long value); + + void writeUint8(String column, int value); + void writeUint16(String column, int value); + void writeUint32(String column, long value); + void writeUint64(String column, long value); + + void writeFloat(String column, float value); + void writeDouble(String column, double value); + + void writeText(String column, String text); + void writeJson(String column, String json); + void writeJsonDocument(String column, String jsonDocument); + + void writeBytes(String column, byte[] bytes); + void writeYson(String column, byte[] yson); + + void writeUuid(String column, UUID uuid); + + void writeDate(String column, LocalDate date); + void writeDatetime(String column, LocalDateTime datetime); + void writeTimestamp(String column, Instant instant); + void writeInterval(String column, Duration interval); + + void writeDate32(String column, LocalDate date32); + void writeDatetime64(String column, LocalDateTime datetime64); + void writeTimestamp64(String column, Instant instant64); + void writeInterval64(String column, Duration interval64); + + void writeDecimal(String column, DecimalValue value); + } + + private final VectorSchemaRoot vsr; + private final Map> columns = new HashMap<>(); + + private ApacheArrowWriter(BufferAllocator allocator, List columnsList) { + FieldVector[] vectors = new FieldVector[columnsList.size()]; + for (int idx = 0; idx < columnsList.size(); idx += 1) { + ColumnInfo column = columnsList.get(idx); + Column vector = column.createVector(allocator); + vectors[idx] = vector.vector; + columns.put(column.name, vector); + } + + this.vsr = VectorSchemaRoot.of(vectors); + } + + public Batch createNewBatch(int 
estimatedRowsCount) { + // reset all + for (Column column: columns.values()) { + column.allocateNew(estimatedRowsCount); + } + return new BatchImpl(); + } + + @Override + public void close() { + vsr.close(); + } + + private class BatchImpl implements Batch { + private int rowIndex = 0; + + @Override + public Row writeNextRow() { + return new RowImpl(rowIndex++); + } + + @Override + public BulkUpsertArrowData buildBatch() throws IOException { + vsr.setRowCount(rowIndex); + return new BulkUpsertArrowData(serializeSchema(), serializeBatch()); + } + + private ByteString serializeSchema() throws IOException { + try (ByteString.Output out = ByteString.newOutput()) { + try (WriteChannel channel = new WriteChannel(Channels.newChannel(out))) { + MessageSerializer.serialize(channel, vsr.getSchema()); + return out.toByteString(); + } + } + } + + private ByteString serializeBatch() throws IOException { + try (ByteString.Output out = ByteString.newOutput()) { + try (WriteChannel channel = new WriteChannel(Channels.newChannel(out))) { + VectorUnloader loader = new VectorUnloader(vsr); + try (ArrowRecordBatch batch = loader.getRecordBatch()) { + MessageSerializer.serialize(channel, batch); + return out.toByteString(); + } + } + } + } + } + + private class RowImpl implements Row { + private final int rowIndex; + + RowImpl(int rowIndex) { + this.rowIndex = rowIndex; + } + + private Column find(String column) { + Column vector = columns.get(column); + if (vector == null) { + throw new IllegalArgumentException("Column '" + column + "' not found"); + } + return vector; + } + + @Override + public void writeNull(String column) { + find(column).writeNull(rowIndex); + } + + @Override + public void writeBool(String column, boolean value) { + find(column).writeBool(rowIndex, value); + } + + @Override + public void writeInt8(String column, byte value) { + find(column).writeInt8(rowIndex, value); + } + + @Override + public void writeInt16(String column, short value) { + 
find(column).writeInt16(rowIndex, value); + } + + @Override + public void writeInt32(String column, int value) { + find(column).writeInt32(rowIndex, value); + } + + @Override + public void writeInt64(String column, long value) { + find(column).writeInt64(rowIndex, value); + } + + @Override + public void writeUint8(String column, int value) { + find(column).writeUint8(rowIndex, value); + } + + @Override + public void writeUint16(String column, int value) { + find(column).writeUint16(rowIndex, value); + } + + @Override + public void writeUint32(String column, long value) { + find(column).writeUint32(rowIndex, value); + } + + @Override + public void writeUint64(String column, long value) { + find(column).writeUint64(rowIndex, value); + } + + @Override + public void writeFloat(String column, float value) { + find(column).writeFloat(rowIndex, value); + } + + @Override + public void writeDouble(String column, double value) { + find(column).writeDouble(rowIndex, value); + } + + @Override + public void writeText(String column, String text) { + find(column).writeText(rowIndex, text); + } + + @Override + public void writeJson(String column, String json) { + find(column).writeJson(rowIndex, json); + } + + @Override + public void writeJsonDocument(String column, String jsonDocument) { + find(column).writeJsonDocument(rowIndex, jsonDocument); + } + + @Override + public void writeBytes(String column, byte[] bytes) { + find(column).writeBytes(rowIndex, bytes); + } + + @Override + public void writeYson(String column, byte[] yson) { + find(column).writeYson(rowIndex, yson); + } + + @Override + public void writeUuid(String column, UUID uuid) { + find(column).writeUuid(rowIndex, uuid); + } + + @Override + public void writeDate(String column, LocalDate date) { + find(column).writeDate(rowIndex, date); + } + + @Override + public void writeDatetime(String column, LocalDateTime datetime) { + find(column).writeDatetime(rowIndex, datetime); + } + + @Override + public void 
writeTimestamp(String column, Instant instant) { + find(column).writeTimestamp(rowIndex, instant); + } + + @Override + public void writeInterval(String column, Duration interval) { + find(column).writeInterval(rowIndex, interval); + } + + @Override + public void writeDate32(String column, LocalDate date32) { + find(column).writeDate32(rowIndex, date32); + } + + @Override + public void writeDatetime64(String column, LocalDateTime datetime64) { + find(column).writeDatetime64(rowIndex, datetime64); + } + + @Override + public void writeTimestamp64(String column, Instant instant64) { + find(column).writeTimestamp64(rowIndex, instant64); + } + + @Override + public void writeInterval64(String column, Duration interval64) { + find(column).writeInterval64(rowIndex, interval64); + } + + @Override + public void writeDecimal(String column, DecimalValue value) { + find(column).writeDecimal(rowIndex, value); + } + } + + private abstract static class Column { + protected final Field field; + protected final Type type; + protected final T vector; + + Column(Field field, Type type, T vector) { + this.field = field; + this.type = type; + this.vector = vector; + } + + protected IllegalStateException error(String method) { + return new IllegalStateException("cannot call " + method + ", actual type: " + type); + } + + public abstract void allocateNew(int estimated); + + void writeNull(int rowIndex) { + if (field.isNullable()) { + vector.setNull(rowIndex); + } else { + throw error("writeNull"); + } + } + + void writeBool(int rowIndex, boolean value) { + throw error("writeBool"); + } + + void writeInt8(int rowIndex, byte value) { + throw error("writeInt8"); + } + + void writeInt16(int rowIndex, short value) { + throw error("writeInt16"); + } + + void writeInt32(int rowIndex, int value) { + throw error("writeInt32"); + } + + void writeInt64(int rowIndex, long value) { + throw error("writeInt64"); + } + + void writeUint8(int rowIndex, int value) { + throw error("writeUint8"); + } + + void 
writeUint16(int rowIndex, int value) { + throw error("writeUint16"); + } + + void writeUint32(int rowIndex, long value) { + throw error("writeUint32"); + } + + void writeUint64(int rowIndex, long value) { + throw error("writeUint64"); + } + + void writeFloat(int rowIndex, float value) { + throw error("writeFloat"); + } + + void writeDouble(int rowIndex, double value) { + throw error("writeDouble"); + } + + void writeText(int rowIndex, String text) { + throw error("writeText"); + } + + void writeJson(int rowIndex, String json) { + throw error("writeJson"); + } + + void writeJsonDocument(int rowIndex, String jsonDocument) { + throw error("writeJsonDocument"); + } + + void writeBytes(int rowIndex, byte[] bytes) { + throw error("writeBytes"); + } + + void writeYson(int rowIndex, byte[] yson) { + throw error("writeYson"); + } + + void writeUuid(int rowIndex, UUID yson) { + throw error("writeUuid"); + } + + void writeDate(int rowIndex, LocalDate date) { + throw error("writeDate"); + } + + void writeDatetime(int rowIndex, LocalDateTime datetime) { + throw error("writeDatetime"); + } + + void writeTimestamp(int rowIndex, Instant instant) { + throw error("writeTimestamp"); + } + + void writeInterval(int rowIndex, Duration interval) { + throw error("writeInterval"); + } + + void writeDate32(int rowIndex, LocalDate date32) { + throw error("writeDate32"); + } + + void writeDatetime64(int rowIndex, LocalDateTime datetime64) { + throw error("writeDatetime64"); + } + + void writeTimestamp64(int rowIndex, Instant instant64) { + throw error("writeTimestamp64"); + } + + void writeInterval64(int rowIndex, Duration interval64) { + throw error("writeInterval64"); + } + + void writeDecimal(int rowIndex, DecimalValue value) { + throw error("writeDecimal"); + } + } + + private static class FixedWidthColumn extends Column { + + FixedWidthColumn(Field field, Type type, T vector) { + super(field, type, vector); + } + + @Override + public void allocateNew(int estimated) { + 
vector.allocateNew(estimated); + } + } + + private static class VariableWidthColumn extends Column { + + VariableWidthColumn(Field field, Type type, T vector) { + super(field, type, vector); + } + + @Override + public void allocateNew(int estimated) { + vector.allocateNew(estimated); + } + } + + private static class TinyIntColumn extends FixedWidthColumn { + + TinyIntColumn(Type type, BufferAllocator allocator, Field field) { + super(field, type, new TinyIntVector(field, allocator)); + } + + @Override + void writeBool(int rowIndex, boolean value) { + if (type != PrimitiveType.Bool) { + throw error("writeBool"); + } + vector.setSafe(rowIndex, value ? 1 : 0); + } + + @Override + void writeInt8(int rowIndex, byte value) { + if (type != PrimitiveType.Int8) { + throw error("writeInt8"); + } + vector.setSafe(rowIndex, value); + } + + @Override + void writeUint8(int rowIndex, int value) { + if (type != PrimitiveType.Uint8) { + throw error("writeUint8"); + } + vector.setSafe(rowIndex, value); + } + } + + private static class SmallIntColumn extends FixedWidthColumn { + + SmallIntColumn(Type type, BufferAllocator allocator, Field field) { + super(field, type, new SmallIntVector(field, allocator)); + } + + @Override + void writeInt16(int rowIndex, short value) { + if (type != PrimitiveType.Int16) { + throw error("writeInt16"); + } + vector.setSafe(rowIndex, value); + } + + @Override + void writeUint16(int rowIndex, int value) { + if (type != PrimitiveType.Uint16) { + throw error("writeUint16"); + } + vector.setSafe(rowIndex, value); + } + + @Override + void writeDate(int rowIndex, LocalDate date) { + if (type != PrimitiveType.Date) { + throw error("writeDate"); + } + vector.setSafe(rowIndex, (int) date.toEpochDay()); + } + } + + private static class IntColumn extends FixedWidthColumn { + + IntColumn(Type type, BufferAllocator allocator, Field field) { + super(field, type, new IntVector(field, allocator)); + } + + @Override + void writeInt32(int rowIndex, int value) { + if 
(type != PrimitiveType.Int32) { + throw error("writeInt32"); + } + vector.setSafe(rowIndex, value); + } + + @Override + void writeUint32(int rowIndex, long value) { + if (type != PrimitiveType.Uint32) { + throw error("writeUint32"); + } + vector.setSafe(rowIndex, (int) value); + } + + @Override + void writeDate32(int rowIndex, LocalDate date) { + if (type != PrimitiveType.Date32) { + throw error("writeDate32"); + } + vector.setSafe(rowIndex, (int) date.toEpochDay()); + } + + @Override + void writeDatetime(int rowIndex, LocalDateTime datetime) { + if (type != PrimitiveType.Datetime) { + throw error("writeDatetime"); + } + vector.setSafe(rowIndex, (int) datetime.toEpochSecond(ZoneOffset.UTC)); + } + } + + private static class BigIntColumn extends FixedWidthColumn { + + BigIntColumn(Type type, BufferAllocator allocator, Field field) { + super(field, type, new BigIntVector(field, allocator)); + } + + @Override + void writeInt64(int rowIndex, long value) { + if (type != PrimitiveType.Int64) { + throw error("writeInt64"); + } + vector.setSafe(rowIndex, value); + } + + @Override + void writeUint64(int rowIndex, long value) { + if (type != PrimitiveType.Uint64) { + throw error("writeUint64"); + } + vector.setSafe(rowIndex, value); + } + + @Override + void writeDatetime64(int rowIndex, LocalDateTime datetime) { + if (type != PrimitiveType.Datetime64) { + throw error("writeDatetime64"); + } + vector.setSafe(rowIndex, datetime.toEpochSecond(ZoneOffset.UTC)); + } + + @Override + void writeTimestamp(int rowIndex, Instant value) { + if (type != PrimitiveType.Timestamp) { + throw error("writeTimestamp"); + } + long micros = value.getEpochSecond() * 1000000L + value.getNano() / 1000; + vector.setSafe(rowIndex, micros); + } + + @Override + void writeTimestamp64(int rowIndex, Instant value) { + if (type != PrimitiveType.Timestamp64) { + throw error("writeTimestamp64"); + } + long micros = value.getEpochSecond() * 1000000L + value.getNano() / 1000; + vector.setSafe(rowIndex, micros); 
+ } + + @Override + void writeInterval(int rowIndex, Duration duration) { + if (type != PrimitiveType.Interval) { + throw error("writeInterval"); + } + long micros = duration.getSeconds() * 1000000L + duration.getNano() / 1000; + vector.setSafe(rowIndex, micros); + } + + @Override + void writeInterval64(int rowIndex, Duration duration) { + if (type != PrimitiveType.Interval64) { + throw error("writeInterval64"); + } + long micros = duration.getSeconds() * 1000000L + duration.getNano() / 1000; + vector.setSafe(rowIndex, micros); + } + } + + private static class FloatColumn extends FixedWidthColumn { + + FloatColumn(Type type, BufferAllocator allocator, Field field) { + super(field, type, new Float4Vector(field, allocator)); + } + + @Override + void writeFloat(int rowIndex, float value) { + vector.setSafe(rowIndex, value); + } + } + + private static class DoubleColumn extends FixedWidthColumn { + + DoubleColumn(Type type, BufferAllocator allocator, Field field) { + super(field, type, new Float8Vector(field, allocator)); + } + + @Override + void writeDouble(int rowIndex, double value) { + vector.setSafe(rowIndex, value); + } + } + + private static class VarCharColumn extends VariableWidthColumn { + + VarCharColumn(Type type, BufferAllocator allocator, Field field) { + super(field, type, new VarCharVector(field, allocator)); + } + + @Override + void writeText(int rowIndex, String text) { + if (type != PrimitiveType.Text) { + throw error("writeText"); + } + vector.setSafe(rowIndex, text.getBytes(StandardCharsets.UTF_8)); + } + + @Override + void writeJson(int rowIndex, String json) { + if (type != PrimitiveType.Json) { + throw error("writeJson"); + } + vector.setSafe(rowIndex, json.getBytes(StandardCharsets.UTF_8)); + } + + @Override + void writeJsonDocument(int rowIndex, String jsonDocument) { + if (type != PrimitiveType.JsonDocument) { + throw error("writeJsonDocument"); + } + vector.setSafe(rowIndex, jsonDocument.getBytes(StandardCharsets.UTF_8)); + } + } + + private 
static class VarBinaryColumn extends VariableWidthColumn { + + VarBinaryColumn(Type type, BufferAllocator allocator, Field field) { + super(field, type, new VarBinaryVector(field, allocator)); + } + + @Override + void writeBytes(int rowIndex, byte[] bytes) { + if (type != PrimitiveType.Bytes) { + throw error("writeBytes"); + } + vector.setSafe(rowIndex, bytes); + } + + @Override + void writeYson(int rowIndex, byte[] yson) { + if (type != PrimitiveType.Yson) { + throw error("writeYson"); + } + vector.setSafe(rowIndex, yson); + } + } + + private static class FixedBinaryColumn extends FixedWidthColumn { + + FixedBinaryColumn(Type type, BufferAllocator allocator, Field field) { + super(field, type, new FixedSizeBinaryVector(field, allocator)); + } + + /** + * @see ProtoValue#newUuid(java.util.UUID) + */ + @Override + void writeUuid(int rowIndex, UUID uuid) { + if (type != PrimitiveType.Uuid) { + throw error("writeUuid"); + } + + ByteBuffer buf = ByteBuffer.allocate(16); + + long msb = uuid.getMostSignificantBits(); + long timeLow = (msb & 0xffffffff00000000L) >>> 32; + long timeMid = (msb & 0x00000000ffff0000L) << 16; + long timeHighAndVersion = (msb & 0x000000000000ffffL) << 48; + buf.putLong(LittleEndian.bswap(timeLow | timeMid | timeHighAndVersion)); + buf.putLong(uuid.getLeastSignificantBits()); + + vector.setSafe(rowIndex, buf.array()); + } + + @Override + void writeDecimal(int rowIndex, DecimalValue value) { + if (type.getKind() != Type.Kind.DECIMAL) { + throw error("writeDecimal"); + } + + ByteBuffer buf = ByteBuffer.allocate(16); + buf.putLong(LittleEndian.bswap(value.getLow())); + buf.putLong(LittleEndian.bswap(value.getHigh())); + + vector.setSafe(rowIndex, buf.array()); + } + } + + private static class ColumnInfo { + private final String name; + private final Type type; + + ColumnInfo(String name, Type type) { + this.name = name; + this.type = type; + } + + private Column createVector(BufferAllocator allocator) { + if (type.getKind() == Type.Kind.OPTIONAL) { 
+ return createColumnVector(allocator, type.unwrapOptional(), t -> Field.nullable(name, t)); + } + + return createColumnVector(allocator, type, t -> Field.notNullable(name, t)); + } + } + + private static Column createColumnVector(BufferAllocator allocator, Type type, Function gen) { + if (type.getKind() == Type.Kind.DECIMAL) { + return new FixedBinaryColumn(type, allocator, gen.apply(new ArrowType.FixedSizeBinary(16))); + } + + if (type.getKind() == Type.Kind.PRIMITIVE) { + switch ((PrimitiveType) type) { + case Bool: + return new TinyIntColumn(type, allocator, gen.apply(new ArrowType.Int(8, false))); + + case Int8: + return new TinyIntColumn(type, allocator, gen.apply(new ArrowType.Int(8, true))); + case Int16: + return new SmallIntColumn(type, allocator, gen.apply(new ArrowType.Int(16, true))); + case Int32: + return new IntColumn(type, allocator, gen.apply(new ArrowType.Int(32, true))); + case Int64: + return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Int(64, true))); + + case Uint8: + return new TinyIntColumn(type, allocator, gen.apply(new ArrowType.Int(8, false))); + case Uint16: + return new SmallIntColumn(type, allocator, gen.apply(new ArrowType.Int(16, false))); + case Uint32: + return new IntColumn(type, allocator, gen.apply(new ArrowType.Int(32, false))); + case Uint64: + return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Int(64, false))); + + case Float: + return new FloatColumn(type, allocator, gen.apply( + new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + )); + case Double: + return new DoubleColumn(type, allocator, gen.apply( + new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + )); + + case Text: + case Json: + case JsonDocument: + return new VarCharColumn(type, allocator, gen.apply(new ArrowType.Utf8())); + + case Bytes: + case Yson: + return new VarBinaryColumn(type, allocator, gen.apply(new ArrowType.Binary())); + + case Uuid: + return new FixedBinaryColumn(type, allocator, gen.apply(new 
ArrowType.FixedSizeBinary(16))); + + case Date: + return new SmallIntColumn(type, allocator, gen.apply(new ArrowType.Int(16, false))); + case Datetime: + return new IntColumn(type, allocator, gen.apply(new ArrowType.Int(32, false))); + case Timestamp: + return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Int(64, true))); + case Interval: + return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Duration(TimeUnit.MICROSECOND))); + + case Date32: + return new IntColumn(type, allocator, gen.apply(new ArrowType.Int(32, true))); + case Datetime64: + return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Int(64, true))); + case Timestamp64: + return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Int(64, true))); + case Interval64: + return new BigIntColumn(type, allocator, gen.apply(new ArrowType.Int(64, true))); + + default: + break; + } + } + + throw new IllegalArgumentException("Type " + type + " is not supported in ArrowWriter"); + } + + public static Schema newSchema() { + return new Schema(); + } + + public static class Schema { + private final List columns = new ArrayList<>(); + + public Schema addColumn(String name, Type type) { + this.columns.add(new ColumnInfo(name, type)); + return this; + } + + public Schema addNullableColumn(String name, Type type) { + return addColumn(name, type.makeOptional()); + } + + public ApacheArrowWriter createWriter(BufferAllocator allocator) { + return new ApacheArrowWriter(allocator, columns); + } + } +} diff --git a/table/src/main/java/tech/ydb/table/query/BulkUpsertArrowData.java b/table/src/main/java/tech/ydb/table/query/BulkUpsertArrowData.java new file mode 100644 index 000000000..1b2d23e8b --- /dev/null +++ b/table/src/main/java/tech/ydb/table/query/BulkUpsertArrowData.java @@ -0,0 +1,35 @@ +package tech.ydb.table.query; + +import com.google.protobuf.ByteString; + +import tech.ydb.proto.formats.YdbFormats; +import tech.ydb.proto.table.YdbTable; + +/** + * + * @author Aleksandr 
Gorshenin + */ +public class BulkUpsertArrowData implements BulkUpsertData { + private final ByteString schema; + private final ByteString data; + + public BulkUpsertArrowData(ByteString schema, ByteString data) { + this.schema = schema; + this.data = data; + } + + public ByteString getSchema() { + return schema; + } + + public ByteString getData() { + return data; + } + + @Override + public void applyToRequest(YdbTable.BulkUpsertRequest.Builder builder) { + builder.setArrowBatchSettings( + YdbFormats.ArrowBatchSettings.newBuilder().setSchema(schema).build() + ).setData(data); + } +} diff --git a/table/src/main/java/tech/ydb/table/query/BulkUpsertData.java b/table/src/main/java/tech/ydb/table/query/BulkUpsertData.java index 73d170ffc..96dee3b1d 100644 --- a/table/src/main/java/tech/ydb/table/query/BulkUpsertData.java +++ b/table/src/main/java/tech/ydb/table/query/BulkUpsertData.java @@ -9,21 +9,14 @@ * * @author Aleksandr Gorshenin */ -public class BulkUpsertData { - private final ValueProtos.TypedValue rows; +public interface BulkUpsertData { + void applyToRequest(YdbTable.BulkUpsertRequest.Builder builder); - public BulkUpsertData(ListValue rows) { - this.rows = ValueProtos.TypedValue.newBuilder() - .setType(rows.getType().toPb()) - .setValue(rows.toPb()) - .build(); + static BulkUpsertData fromRows(ListValue list) { + return new BulkUpsertProtoData(list); } - public BulkUpsertData(ValueProtos.TypedValue rows) { - this.rows = rows; - } - - public void applyToRequest(YdbTable.BulkUpsertRequest.Builder builder) { - builder.setRows(rows); + static BulkUpsertData fromProto(ValueProtos.TypedValue rows) { + return new BulkUpsertProtoData(rows); } } diff --git a/table/src/main/java/tech/ydb/table/query/BulkUpsertProtoData.java b/table/src/main/java/tech/ydb/table/query/BulkUpsertProtoData.java new file mode 100644 index 000000000..17d575f46 --- /dev/null +++ b/table/src/main/java/tech/ydb/table/query/BulkUpsertProtoData.java @@ -0,0 +1,29 @@ +package 
tech.ydb.table.query; + +import tech.ydb.proto.ValueProtos; +import tech.ydb.proto.table.YdbTable; +import tech.ydb.table.values.ListValue; + +/** + * + * @author Aleksandr Gorshenin + */ +public class BulkUpsertProtoData implements BulkUpsertData { + private final ValueProtos.TypedValue rows; + + public BulkUpsertProtoData(ListValue rows) { + this.rows = ValueProtos.TypedValue.newBuilder() + .setType(rows.getType().toPb()) + .setValue(rows.toPb()) + .build(); + } + + public BulkUpsertProtoData(ValueProtos.TypedValue rows) { + this.rows = rows; + } + + @Override + public void applyToRequest(YdbTable.BulkUpsertRequest.Builder builder) { + builder.setRows(rows); + } +} diff --git a/table/src/main/java/tech/ydb/table/utils/LittleEndian.java b/table/src/main/java/tech/ydb/table/utils/LittleEndian.java index b1c7fc56f..816d1ba1e 100644 --- a/table/src/main/java/tech/ydb/table/utils/LittleEndian.java +++ b/table/src/main/java/tech/ydb/table/utils/LittleEndian.java @@ -8,6 +8,9 @@ private LittleEndian() { } /** * Reverses the byte order of a long value. 
package tech.ydb.table.integration;

import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.stream.Collectors;

import com.google.common.hash.Hashing;
import org.junit.Assert;

import tech.ydb.table.description.TableColumn;
import tech.ydb.table.description.TableDescription;
import tech.ydb.table.query.ApacheArrowWriter;
import tech.ydb.table.result.ResultSetReader;
import tech.ydb.table.result.ValueReader;
import tech.ydb.table.values.DecimalType;
import tech.ydb.table.values.DecimalValue;
import tech.ydb.table.values.ListType;
import tech.ydb.table.values.ListValue;
import tech.ydb.table.values.PrimitiveType;
import tech.ydb.table.values.PrimitiveValue;
import tech.ydb.table.values.StructType;
import tech.ydb.table.values.StructValue;
import tech.ydb.table.values.Type;
import tech.ydb.table.values.Value;

/**
 * Test fixture: one record containing a column of every YDB type supported by
 * the bulk upsert path. Provides random generation, serialization both to
 * protobuf ({@link #createProtobufBatch}) and to Apache Arrow
 * ({@link #writeToApacheArrow}), and per-row validation of read-back results.
 *
 * @author Aleksandr Gorshenin
 */
public class AllTypesRecord {
    private static final DecimalType YDB_DECIMAL = DecimalType.getDefault();
    private static final DecimalType BANK_DECIMAL = DecimalType.of(31, 9);
    private static final DecimalType BIG_DECIMAL = DecimalType.of(35, 0);

    // Keys
    private final long id1;
    private final int id2;
    private final byte[] payload;
    private final int length;
    private final String hash;

    // All types; null means the optional column is written as empty
    private Byte v_int8;
    private Short v_int16;
    private Integer v_int32;
    private Long v_int64;

    private Integer v_uint8;
    private Integer v_uint16;
    private Long v_uint32;
    private Long v_uint64;

    private Boolean v_bool;
    private Float v_float;
    private Double v_double;

    private String v_text;
    private String v_json;
    private String v_jsdoc;

    private byte[] v_bytes;
    private byte[] v_yson;

    private UUID v_uuid;

    private LocalDate v_date;
    private LocalDateTime v_datetime;
    private Instant v_timestamp;
    private Duration v_interval;

    private LocalDate v_date32;
    private LocalDateTime v_datetime64;
    private Instant v_timestamp64;
    private Duration v_interval64;

    private DecimalValue v_ydb_decimal;
    private DecimalValue v_bank_decimal;
    private DecimalValue v_big_decimal;

    private AllTypesRecord(long id1, int id2, byte[] payload) {
        // Keys: length and hash are derived from the payload
        this.id1 = id1;
        this.id2 = id2;
        this.length = payload.length;
        this.payload = payload;
        this.hash = Hashing.sha256().hashBytes(payload).toString();
    }

    /**
     * Asserts that column {@code key} of the current row equals {@code value}.
     * Columns absent from {@code columns} are skipped; a {@code null} expected
     * value means the optional item must be empty.
     */
    private <T> void assertValue(Set<String> columns, String key, int idx, Function<ValueReader, T> reader,
            ResultSetReader rs, T value) {
        if (!columns.contains(key)) {
            return;
        }

        String msg = "Row " + idx + " " + key;
        Assert.assertEquals(msg, value != null, rs.getColumn(key).isOptionalItemPresent());

        if (value != null) {
            if (value instanceof byte[]) {
                Assert.assertArrayEquals(msg, (byte[]) value, (byte[]) reader.apply(rs.getColumn(key)));
            } else {
                Assert.assertEquals(msg, value, reader.apply(rs.getColumn(key)));
            }
        }
    }

    /**
     * Validates every column of the row at index {@code idx} against this record.
     *
     * @param columns set of column names that are expected to be present
     * @param idx row index, used only in assertion messages
     * @param rs result set positioned on the row to check
     */
    public void assertRow(Set<String> columns, int idx, ResultSetReader rs) {
        // keys are required
        Assert.assertEquals("Row " + idx + " id1", id1, rs.getColumn("id1").getUint64());
        Assert.assertEquals("Row " + idx + " id2", id2, rs.getColumn("id2").getInt64());
        Assert.assertArrayEquals("Row " + idx + " payload", payload, rs.getColumn("payload").getBytes());
        Assert.assertEquals("Row " + idx + " length", length, rs.getColumn("length").getUint32());
        Assert.assertEquals("Row " + idx + " hash", hash, rs.getColumn("hash").getText());

        // other columns may be skipped or be empty
        assertValue(columns, "Int8", idx, ValueReader::getInt8, rs, v_int8);
        assertValue(columns, "Int16", idx, ValueReader::getInt16, rs, v_int16);
        assertValue(columns, "Int32", idx, ValueReader::getInt32, rs, v_int32);
        assertValue(columns, "Int64", idx, ValueReader::getInt64, rs, v_int64);

        assertValue(columns, "Uint8", idx, ValueReader::getUint8, rs, v_uint8);
        assertValue(columns, "Uint16", idx, ValueReader::getUint16, rs, v_uint16);
        assertValue(columns, "Uint32", idx, ValueReader::getUint32, rs, v_uint32);
        assertValue(columns, "Uint64", idx, ValueReader::getUint64, rs, v_uint64);

        assertValue(columns, "Bool", idx, ValueReader::getBool, rs, v_bool);
        assertValue(columns, "Float", idx, ValueReader::getFloat, rs, v_float);
        assertValue(columns, "Double", idx, ValueReader::getDouble, rs, v_double);

        assertValue(columns, "Text", idx, ValueReader::getText, rs, v_text);
        assertValue(columns, "Json", idx, ValueReader::getJson, rs, v_json);
        assertValue(columns, "JsonDocument", idx, ValueReader::getJsonDocument, rs, v_jsdoc);

        assertValue(columns, "Uuid", idx, ValueReader::getUuid, rs, v_uuid);

        assertValue(columns, "Bytes", idx, ValueReader::getBytes, rs, v_bytes);
        assertValue(columns, "Yson", idx, ValueReader::getYson, rs, v_yson);

        assertValue(columns, "Date", idx, ValueReader::getDate, rs, v_date);
        assertValue(columns, "Datetime", idx, ValueReader::getDatetime, rs, v_datetime);
        assertValue(columns, "Timestamp", idx, ValueReader::getTimestamp, rs, v_timestamp);
        assertValue(columns, "Interval", idx, ValueReader::getInterval, rs, v_interval);

        assertValue(columns, "Date32", idx, ValueReader::getDate32, rs, v_date32);
        assertValue(columns, "Datetime64", idx, ValueReader::getDatetime64, rs, v_datetime64);
        assertValue(columns, "Timestamp64", idx, ValueReader::getTimestamp64, rs, v_timestamp64);
        assertValue(columns, "Interval64", idx, ValueReader::getInterval64, rs, v_interval64);

        assertValue(columns, "YdbDecimal", idx, ValueReader::getDecimal, rs, v_ydb_decimal);
        assertValue(columns, "BankDecimal", idx, ValueReader::getDecimal, rs, v_bank_decimal);
        assertValue(columns, "BigDecimal", idx, ValueReader::getDecimal, rs, v_big_decimal);
    }

    /** Converts a nullable Java value to a YDB value, empty optional for null. */
    private <T> Value<?> makePb(Type type, Function<T, Value<?>> func, T value) {
        return value != null ? func.apply(value) : type.makeOptional().emptyValue();
    }

    /** Builds the struct members map, writing only the columns listed in {@code columnOnly}. */
    private Map<String, Value<?>> toPb(List<String> columnOnly) {
        Map<String, Value<?>> struct = new HashMap<>();

        BiConsumer<String, Value<?>> write = (key, value) -> {
            if (columnOnly.contains(key)) {
                struct.put(key, value);
            }
        };

        // keys are required
        struct.put("id1", PrimitiveValue.newUint64(id1));
        struct.put("id2", PrimitiveValue.newInt64(id2));
        struct.put("payload", PrimitiveValue.newBytesOwn(payload));
        struct.put("length", PrimitiveValue.newUint32(length));
        struct.put("hash", PrimitiveValue.newText(hash));

        // other columns may be skipped or be empty
        write.accept("Int8", makePb(PrimitiveType.Int8, PrimitiveValue::newInt8, v_int8));
        write.accept("Int16", makePb(PrimitiveType.Int16, PrimitiveValue::newInt16, v_int16));
        write.accept("Int32", makePb(PrimitiveType.Int32, PrimitiveValue::newInt32, v_int32));
        write.accept("Int64", makePb(PrimitiveType.Int64, PrimitiveValue::newInt64, v_int64));

        write.accept("Uint8", makePb(PrimitiveType.Uint8, PrimitiveValue::newUint8, v_uint8));
        write.accept("Uint16", makePb(PrimitiveType.Uint16, PrimitiveValue::newUint16, v_uint16));
        write.accept("Uint32", makePb(PrimitiveType.Uint32, PrimitiveValue::newUint32, v_uint32));
        write.accept("Uint64", makePb(PrimitiveType.Uint64, PrimitiveValue::newUint64, v_uint64));

        write.accept("Bool", makePb(PrimitiveType.Bool, PrimitiveValue::newBool, v_bool));
        write.accept("Float", makePb(PrimitiveType.Float, PrimitiveValue::newFloat, v_float));
        write.accept("Double", makePb(PrimitiveType.Double, PrimitiveValue::newDouble, v_double));

        write.accept("Text", makePb(PrimitiveType.Text, PrimitiveValue::newText, v_text));
        write.accept("Json", makePb(PrimitiveType.Json, PrimitiveValue::newJson, v_json));
        write.accept("JsonDocument", makePb(PrimitiveType.JsonDocument, PrimitiveValue::newJsonDocument, v_jsdoc));

        write.accept("Bytes", makePb(PrimitiveType.Bytes, PrimitiveValue::newBytes, v_bytes));
        write.accept("Yson", makePb(PrimitiveType.Yson, PrimitiveValue::newYson, v_yson));

        write.accept("Uuid", makePb(PrimitiveType.Uuid, PrimitiveValue::newUuid, v_uuid));

        write.accept("Date", makePb(PrimitiveType.Date, PrimitiveValue::newDate, v_date));
        write.accept("Datetime", makePb(PrimitiveType.Datetime, PrimitiveValue::newDatetime, v_datetime));
        write.accept("Timestamp", makePb(PrimitiveType.Timestamp, PrimitiveValue::newTimestamp, v_timestamp));
        write.accept("Interval", makePb(PrimitiveType.Interval, PrimitiveValue::newInterval, v_interval));

        write.accept("Date32", makePb(PrimitiveType.Date32, PrimitiveValue::newDate32, v_date32));
        write.accept("Datetime64", makePb(PrimitiveType.Datetime64, PrimitiveValue::newDatetime64, v_datetime64));
        write.accept("Timestamp64", makePb(PrimitiveType.Timestamp64, PrimitiveValue::newTimestamp64, v_timestamp64));
        write.accept("Interval64", makePb(PrimitiveType.Interval64, PrimitiveValue::newInterval64, v_interval64));

        // v -> v instead of Function.identity(): DecimalValue is already a Value
        write.accept("YdbDecimal", makePb(YDB_DECIMAL, v -> v, v_ydb_decimal));
        write.accept("BankDecimal", makePb(BANK_DECIMAL, v -> v, v_bank_decimal));
        write.accept("BigDecimal", makePb(BIG_DECIMAL, v -> v, v_big_decimal));

        return struct;
    }

    /** Writes a nullable value to the Arrow row; skipped columns are not touched. */
    private <T> void writeNullable(Set<String> columns, String key, ApacheArrowWriter.Row row,
            BiConsumer<String, T> writer, T value) {
        if (!columns.contains(key)) {
            return;
        }

        if (value == null) {
            row.writeNull(key);
        } else {
            writer.accept(key, value);
        }
    }

    /**
     * Writes this record as one row of an Apache Arrow batch.
     *
     * @param columnNames columns present in the target schema
     * @param row destination row of the Arrow writer
     */
    public void writeToApacheArrow(Set<String> columnNames, ApacheArrowWriter.Row row) {
        // keys are required
        row.writeUint64("id1", id1);
        row.writeInt64("id2", id2);
        row.writeBytes("payload", payload);
        row.writeUint32("length", length);
        row.writeText("hash", hash);

        // other columns may be skipped or be empty
        writeNullable(columnNames, "Int8", row, row::writeInt8, v_int8);
        writeNullable(columnNames, "Int16", row, row::writeInt16, v_int16);
        writeNullable(columnNames, "Int32", row, row::writeInt32, v_int32);
        writeNullable(columnNames, "Int64", row, row::writeInt64, v_int64);

        writeNullable(columnNames, "Uint8", row, row::writeUint8, v_uint8);
        writeNullable(columnNames, "Uint16", row, row::writeUint16, v_uint16);
        writeNullable(columnNames, "Uint32", row, row::writeUint32, v_uint32);
        writeNullable(columnNames, "Uint64", row, row::writeUint64, v_uint64);

        writeNullable(columnNames, "Bool", row, row::writeBool, v_bool);

        writeNullable(columnNames, "Float", row, row::writeFloat, v_float);
        writeNullable(columnNames, "Double", row, row::writeDouble, v_double);

        writeNullable(columnNames, "Text", row, row::writeText, v_text);
        writeNullable(columnNames, "Json", row, row::writeJson, v_json);
        writeNullable(columnNames, "JsonDocument", row, row::writeJsonDocument, v_jsdoc);

        writeNullable(columnNames, "Bytes", row, row::writeBytes, v_bytes);
        writeNullable(columnNames, "Yson", row, row::writeYson, v_yson);

        writeNullable(columnNames, "Uuid", row, row::writeUuid, v_uuid);

        writeNullable(columnNames, "Date", row, row::writeDate, v_date);
        writeNullable(columnNames, "Datetime", row, row::writeDatetime, v_datetime);
        writeNullable(columnNames, "Timestamp", row, row::writeTimestamp, v_timestamp);
        writeNullable(columnNames, "Interval", row, row::writeInterval, v_interval);

        writeNullable(columnNames, "Date32", row, row::writeDate32, v_date32);
        writeNullable(columnNames, "Datetime64", row, row::writeDatetime64, v_datetime64);
        writeNullable(columnNames, "Timestamp64", row, row::writeTimestamp64, v_timestamp64);
        writeNullable(columnNames, "Interval64", row, row::writeInterval64, v_interval64);

        writeNullable(columnNames, "YdbDecimal", row, row::writeDecimal, v_ydb_decimal);
        writeNullable(columnNames, "BankDecimal", row, row::writeDecimal, v_bank_decimal);
        writeNullable(columnNames, "BigDecimal", row, row::writeDecimal, v_big_decimal);
    }

    /**
     * Generates a random decimal for the given type, occasionally producing the
     * special values -inf, 0 and +inf to exercise edge cases.
     */
    public static DecimalValue randomDecimal(DecimalType type, Random rnd) {
        int kind = rnd.nextInt(1000);
        switch (kind) {
            case 0: return type.getNegInf();
            case 499: return type.newValue(0);
            case 999: return type.getInf();
            default:
                break;
        }
        BigInteger unscaled = new BigInteger(1 + rnd.nextInt(type.getPrecision() * 3 - 1), rnd);
        if (kind < 500) {
            unscaled = unscaled.negate();
        }
        return type.newValueUnscaled(unscaled);
    }

    /** Returns null with probability 1/10, otherwise the value itself. */
    private static <T> T nullable(Random rnd, T value) {
        return rnd.nextInt(10) == 0 ? null : value;
    }

    /**
     * Creates a record with random values in every column.
     *
     * <p>Fix: the original expressions like
     * {@code 0x7FFFFFFFFFL & rnd.nextLong() - 0x4000000000L} applied the mask
     * AFTER the subtraction ({@code -} binds tighter than {@code &} in Java),
     * so the signed 64-bit types never received negative/pre-epoch values.
     * The mask is now explicitly parenthesized before the shift.
     */
    public static AllTypesRecord random(long id1, int id2, Random rnd) {
        int length = 10 + rnd.nextInt(256);
        byte[] payload = new byte[length];
        rnd.nextBytes(payload);

        AllTypesRecord r = new AllTypesRecord(id1, id2, payload);

        // All types
        r.v_int8 = nullable(rnd, (byte) rnd.nextInt());
        r.v_int16 = nullable(rnd, (short) rnd.nextInt());
        r.v_int32 = nullable(rnd, rnd.nextInt());
        r.v_int64 = nullable(rnd, rnd.nextLong());

        r.v_uint8 = nullable(rnd, rnd.nextInt() & 0xFF);
        r.v_uint16 = nullable(rnd, rnd.nextInt() & 0xFFFF);
        r.v_uint32 = nullable(rnd, rnd.nextLong() & 0xFFFFFFFFL);
        r.v_uint64 = nullable(rnd, rnd.nextLong());

        r.v_bool = nullable(rnd, rnd.nextBoolean());
        r.v_float = nullable(rnd, rnd.nextFloat());
        r.v_double = nullable(rnd, rnd.nextDouble());

        r.v_text = nullable(rnd, "Text" + rnd.nextInt(1000));
        r.v_json = nullable(rnd, "{\"json\":" + (1000 + rnd.nextInt(1000)) + "}");
        r.v_jsdoc = nullable(rnd, "{\"document\":" + (2000 + rnd.nextInt(1000)) + "}");

        r.v_bytes = nullable(rnd, ("Bytes " + (3000 + rnd.nextInt(1000))).getBytes(StandardCharsets.UTF_8));
        r.v_yson = nullable(rnd, ("{yson=" + (4000 + rnd.nextInt(1000)) + "}").getBytes(StandardCharsets.UTF_8));

        r.v_uuid = nullable(rnd, UUID.nameUUIDFromBytes(("UUID" + rnd.nextInt()).getBytes(StandardCharsets.UTF_8)));

        // unsigned 32-bit date types: non-negative values only
        r.v_date = nullable(rnd, LocalDate.ofEpochDay(rnd.nextInt(5000)));
        r.v_datetime = nullable(rnd, LocalDateTime.ofEpochSecond(0x60FFFFFF & rnd.nextLong(), 0, ZoneOffset.UTC));
        r.v_timestamp = nullable(rnd, Instant.ofEpochSecond(0x3FFFFFFFL & rnd.nextLong(),
                rnd.nextInt(1000000) * 1000));
        // signed: mask first, then shift down so negative intervals are generated too
        r.v_interval = nullable(rnd, Duration.ofNanos(((0x7FFFFFFFFFL & rnd.nextLong()) - 0x4000000000L) * 1000));

        // signed 64-bit date types: shift the masked value to cover pre-epoch values
        r.v_date32 = nullable(rnd, LocalDate.ofEpochDay(rnd.nextInt(5000) - 2500));
        r.v_datetime64 = nullable(rnd, LocalDateTime.ofEpochSecond((0x60FFFFFF & rnd.nextLong()) - 0x30FFFFFF, 0,
                ZoneOffset.UTC));
        r.v_timestamp64 = nullable(rnd, Instant.ofEpochSecond((0x7FFFFFFFFFL & rnd.nextLong()) - 0x4000000000L,
                rnd.nextInt(1000000) * 1000));
        r.v_interval64 = nullable(rnd, Duration.ofNanos(((0x7FFFFFFFFFL & rnd.nextLong()) - 0x4000000000L) * 1000));

        r.v_ydb_decimal = nullable(rnd, randomDecimal(YDB_DECIMAL, rnd));
        r.v_bank_decimal = nullable(rnd, randomDecimal(BANK_DECIMAL, rnd));
        r.v_big_decimal = nullable(rnd, randomDecimal(BIG_DECIMAL, rnd));

        return r;
    }

    /**
     * Generates a reproducible batch of {@code count} records: the RNG seed is
     * derived from the arguments, so equal arguments yield equal batches.
     */
    public static List<AllTypesRecord> randomBatch(int id1, int id2_start, int count) {
        Random rnd = new Random(id1 * count + id2_start);
        List<AllTypesRecord> batch = new ArrayList<>(count);
        for (int idx = 0; idx < count; idx += 1) {
            batch.add(random(id1, id2_start + idx, rnd));
        }
        return batch;
    }

    /** Serializes the batch as a protobuf ListValue matching the table schema. */
    public static ListValue createProtobufBatch(TableDescription desc, List<AllTypesRecord> batch) {
        List<String> columnNames = desc.getColumns().stream().map(TableColumn::getName).collect(Collectors.toList());
        List<Type> columnTypes = desc.getColumns().stream().map(TableColumn::getType).collect(Collectors.toList());

        StructType type = StructType.of(columnNames, columnTypes);
        List<Value<?>> values = batch.stream().map(r -> (Value<?>) type.newValue(r.toPb(columnNames)))
                .collect(Collectors.toList());

        return ListType.of(type).newValue(values);
    }

    /**
     * Builds the table schema used by the bulk upsert tests.
     *
     * @param isColumnShard true for a COLUMN store table (some types are
     *                      excluded there, see linked issues), false for ROW
     */
    public static TableDescription createTableDescription(boolean isColumnShard) {
        TableDescription.Builder builder = TableDescription.newBuilder()
                .addNonnullColumn("id1", PrimitiveType.Uint64)
                .addNonnullColumn("id2", PrimitiveType.Int64)
                .addNonnullColumn("payload", PrimitiveType.Bytes)
                .addNonnullColumn("length", PrimitiveType.Uint32)
                .addNonnullColumn("hash", PrimitiveType.Text)

                .addNullableColumn("Int8", PrimitiveType.Int8)
                .addNullableColumn("Int16", PrimitiveType.Int16)
                .addNullableColumn("Int32", PrimitiveType.Int32)
                .addNullableColumn("Int64", PrimitiveType.Int64)
                .addNullableColumn("Uint8", PrimitiveType.Uint8)
                .addNullableColumn("Uint16", PrimitiveType.Uint16)
                .addNullableColumn("Uint32", PrimitiveType.Uint32)
                .addNullableColumn("Uint64", PrimitiveType.Uint64)
                .addNullableColumn("Bool", PrimitiveType.Bool)
                .addNullableColumn("Float", PrimitiveType.Float)
                .addNullableColumn("Double", PrimitiveType.Double)
                .addNullableColumn("Text", PrimitiveType.Text)
                .addNullableColumn("Json", PrimitiveType.Json)
                .addNullableColumn("JsonDocument", PrimitiveType.JsonDocument)
                .addNullableColumn("Bytes", PrimitiveType.Bytes)
                .addNullableColumn("Yson", PrimitiveType.Yson);

        // https://github.com/ydb-platform/ydb/issues/13047
        if (!isColumnShard) {
            builder = builder.addNullableColumn("Uuid", PrimitiveType.Uuid);
        }

        builder = builder
                .addNullableColumn("Date", PrimitiveType.Date)
                .addNullableColumn("Datetime", PrimitiveType.Datetime)
                .addNullableColumn("Timestamp", PrimitiveType.Timestamp);

        // https://github.com/ydb-platform/ydb/issues/13050
        if (!isColumnShard) {
            builder = builder.addNullableColumn("Interval", PrimitiveType.Interval);
        }

        builder = builder
                .addNullableColumn("Date32", PrimitiveType.Date32)
                .addNullableColumn("Datetime64", PrimitiveType.Datetime64)
                .addNullableColumn("Timestamp64", PrimitiveType.Timestamp64)
                .addNullableColumn("Interval64", PrimitiveType.Interval64)
                .addNullableColumn("YdbDecimal", YDB_DECIMAL)
                .addNullableColumn("BankDecimal", BANK_DECIMAL)
                .addNullableColumn("BigDecimal", BIG_DECIMAL);

        if (isColumnShard) {
            builder = builder.setPrimaryKey("hash").setStoreType(TableDescription.StoreType.COLUMN);
        } else {
            builder = builder.setPrimaryKeys("id1", "id2").setStoreType(TableDescription.StoreType.ROW);
        }

        return builder.build();
    }
}
package tech.ydb.table.integration;

import java.io.IOException;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletionException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.junit.After;
import org.junit.Assert;
import org.junit.ClassRule;
import org.junit.Test;

import tech.ydb.core.grpc.GrpcReadStream;
import tech.ydb.table.SessionRetryContext;
import tech.ydb.table.description.TableColumn;
import tech.ydb.table.description.TableDescription;
import tech.ydb.table.impl.SimpleTableClient;
import tech.ydb.table.query.ApacheArrowWriter;
import tech.ydb.table.query.BulkUpsertData;
import tech.ydb.table.query.Params;
import tech.ydb.table.result.ResultSetReader;
import tech.ydb.table.rpc.grpc.GrpcTableRpc;
import tech.ydb.table.settings.ExecuteScanQuerySettings;
import tech.ydb.table.settings.ExecuteSchemeQuerySettings;
import tech.ydb.table.values.ListValue;
import tech.ydb.table.values.PrimitiveValue;
import tech.ydb.test.junit4.GrpcTransportRule;

/**
 * Integration tests for bulk upsert: protobuf and Apache Arrow payloads,
 * against both ROW (data shard) and COLUMN (column shard) tables.
 *
 * @author Aleksandr Gorshenin
 */
public class BulkUpsertTest {
    @ClassRule
    public static final GrpcTransportRule YDB = new GrpcTransportRule();
    private static final SimpleTableClient client = SimpleTableClient.newClient(GrpcTableRpc.useTransport(YDB)).build();
    private static final SessionRetryContext retryCtx = SessionRetryContext.create(client)
            .idempotent(false)
            .build();

    private static final String TEST_TABLE = "arrow/test_table";
    private static final String DROP_TABLE_YQL = "DROP TABLE IF EXISTS `" + TEST_TABLE + "`;";
    private static final String SELECT_TABLE_YQL = "DECLARE $id1 AS Uint64; SELECT * FROM `" + TEST_TABLE + "` "
            + "WHERE id1 = $id1 ORDER BY id2";

    private static String tablePath() {
        return YDB.getDatabase() + "/" + TEST_TABLE;
    }

    /** Drops the test table after every test so each test starts clean. */
    @After
    public void cleanTable() {
        retryCtx.supplyStatus(session -> session.executeSchemeQuery(DROP_TABLE_YQL, new ExecuteSchemeQuerySettings()))
                .join().expectSuccess("cannot drop table");
    }

    private static void createTable(TableDescription table) {
        retryCtx.supplyStatus(s -> s.createTable(tablePath(), table))
                .join().expectSuccess("Cannot create table");
    }

    private static void bulkUpsert(BulkUpsertData data) {
        retryCtx.supplyStatus(session -> session.executeBulkUpsert(tablePath(), data))
                .join().expectSuccess("bulk upsert problem in table " + tablePath());
    }

    private static void bulkUpsert(ListValue rows) {
        retryCtx.supplyStatus(session -> session.executeBulkUpsert(tablePath(), rows))
                .join().expectSuccess("bulk upsert problem in table " + tablePath());
    }

    /**
     * Scans all rows with the given id1 (ordered by id2), feeding each row and
     * its index to {@code validator}.
     *
     * @return number of rows read
     */
    private static int readTable(long id1, BiConsumer<Integer, ResultSetReader> validator) {
        AtomicInteger count = new AtomicInteger();

        try {
            retryCtx.supplyStatus(session -> {
                count.set(0); // a retry re-reads from scratch

                GrpcReadStream<ResultSetReader> stream = session.executeScanQuery(SELECT_TABLE_YQL,
                        Params.of("$id1", PrimitiveValue.newUint64(id1)),
                        ExecuteScanQuerySettings.newBuilder().build()
                );

                return stream.start((rs) -> {
                    while (rs.next()) {
                        int idx = count.getAndIncrement();
                        validator.accept(idx, rs);
                    }
                });
            }).join().expectSuccess("Cannot read table " + TEST_TABLE);
        } catch (CompletionException ex) {
            // unwrap assertion failures thrown inside the stream callback
            if (ex.getCause() instanceof AssertionError) {
                throw (AssertionError) ex.getCause();
            }
            throw ex;
        }

        return count.get();
    }

    /** Reads rows with the given id1 and asserts they match {@code batch} in order and count. */
    private static void assertTableContent(long id1, Set<String> columnNames, List<AllTypesRecord> batch) {
        int rowsCount = readTable(id1, (idx, rs) -> {
            Assert.assertTrue("Unexpected row index", idx < batch.size());
            batch.get(idx).assertRow(columnNames, idx, rs);
        });
        Assert.assertEquals(batch.size(), rowsCount);
    }

    @Test
    public void writeProtobufToDataShardTest() {
        // Create table
        TableDescription table = AllTypesRecord.createTableDescription(false);
        createTable(table);

        Set<String> columnNames = table.getColumns().stream().map(TableColumn::getName).collect(Collectors.toSet());

        // Write & read batch of 1000 records with id1 = 1
        List<AllTypesRecord> batch1 = AllTypesRecord.randomBatch(1, 1, 1000);
        bulkUpsert(AllTypesRecord.createProtobufBatch(table, batch1));
        assertTableContent(1, columnNames, batch1);

        // Write & read batch of 2000 records with id1 = 2
        List<AllTypesRecord> batch2 = AllTypesRecord.randomBatch(2, 1, 2000);
        bulkUpsert(AllTypesRecord.createProtobufBatch(table, batch2));
        assertTableContent(2, columnNames, batch2);
    }

    @Test
    public void writeProtobufToColumnShardTable() {
        // Create table
        TableDescription table = AllTypesRecord.createTableDescription(true);
        Set<String> columnNames = table.getColumns().stream().map(TableColumn::getName).collect(Collectors.toSet());

        createTable(table);

        // Write & read batch of 2500 records with id1 = 1
        List<AllTypesRecord> batch1 = AllTypesRecord.randomBatch(1, 1, 2500);
        bulkUpsert(AllTypesRecord.createProtobufBatch(table, batch1));
        assertTableContent(1, columnNames, batch1);

        // Write & read batch of 5000 records with id1 = 2
        List<AllTypesRecord> batch2 = AllTypesRecord.randomBatch(2, 1, 5000);
        bulkUpsert(AllTypesRecord.createProtobufBatch(table, batch2));
        assertTableContent(2, columnNames, batch2);
    }

    @Test
    public void writeApacheArrowToDataShardTest() {
        // Create table
        TableDescription table = AllTypesRecord.createTableDescription(false);
        createTable(table);

        Set<String> columnNames = table.getColumns().stream().map(TableColumn::getName).collect(Collectors.toSet());

        ApacheArrowWriter.Schema schema = ApacheArrowWriter.newSchema();
        table.getColumns().forEach(column -> schema.addColumn(column.getName(), column.getType()));

        // Create batch of 1000 records
        List<AllTypesRecord> batch1 = AllTypesRecord.randomBatch(1, 1, 1000);
        // Create batch of 2000 records
        List<AllTypesRecord> batch2 = AllTypesRecord.randomBatch(2, 1, 2000);

        try (BufferAllocator allocator = new RootAllocator()) {
            try (ApacheArrowWriter writer = schema.createWriter(allocator)) {
                // create batch with estimated size
                ApacheArrowWriter.Batch data1 = writer.createNewBatch(1000);
                batch1.forEach(r -> r.writeToApacheArrow(columnNames, data1.writeNextRow()));
                bulkUpsert(data1.buildBatch());

                // create batch without estimated size
                ApacheArrowWriter.Batch data2 = writer.createNewBatch(0);
                batch2.forEach(r -> r.writeToApacheArrow(columnNames, data2.writeNextRow()));
                bulkUpsert(data2.buildBatch());

            } catch (IOException ex) {
                throw new AssertionError("Cannot serialize apache arrow", ex);
            }
        }

        assertTableContent(1, columnNames, batch1);
        assertTableContent(2, columnNames, batch2);
    }

    @Test
    public void writeApacheArrowToColumnShardTest() {
        // Create table
        TableDescription table = AllTypesRecord.createTableDescription(true);
        createTable(table);

        Set<String> columnNames = table.getColumns().stream().map(TableColumn::getName).collect(Collectors.toSet());

        ApacheArrowWriter.Schema schema = ApacheArrowWriter.newSchema();
        table.getColumns().forEach(column -> schema.addColumn(column.getName(), column.getType()));

        // Create batch of 2500 records
        List<AllTypesRecord> batch1 = AllTypesRecord.randomBatch(1, 1, 2500);
        // Create batch of 5000 records
        List<AllTypesRecord> batch2 = AllTypesRecord.randomBatch(2, 1, 5000);

        try (BufferAllocator allocator = new RootAllocator()) {
            try (ApacheArrowWriter writer = schema.createWriter(allocator)) {
                // create batch without estimated size
                ApacheArrowWriter.Batch data1 = writer.createNewBatch(0);
                batch1.forEach(r -> r.writeToApacheArrow(columnNames, data1.writeNextRow()));
                bulkUpsert(data1.buildBatch());

                // create batch with estimated size
                ApacheArrowWriter.Batch data2 = writer.createNewBatch(5000);
                batch2.forEach(r -> r.writeToApacheArrow(columnNames, data2.writeNextRow()));
                bulkUpsert(data2.buildBatch());

            } catch (IOException ex) {
                throw new AssertionError("Cannot serialize apache arrow", ex);
            }
        }

        assertTableContent(1, columnNames, batch1);
        assertTableContent(2, columnNames, batch2);
    }
}
BulkUpsertData.fromProto(ProtoValue.toTypedValue(ListType.of(batchType).newValue(batchData))) )).join().expectSuccess("bulk upsert problem in table " + tablePath); } diff --git a/table/src/test/java/tech/ydb/table/integration/ValuesReadTest.java b/table/src/test/java/tech/ydb/table/integration/ValuesReadTest.java index ba91f67f4..89d1faf07 100644 --- a/table/src/test/java/tech/ydb/table/integration/ValuesReadTest.java +++ b/table/src/test/java/tech/ydb/table/integration/ValuesReadTest.java @@ -468,7 +468,6 @@ public void date32datetime64timestamp64interval64() { @Test public void timestamp64ReadTest() { - System.out.println(Instant.ofEpochSecond(-4611669897600L)); DataQueryResult result = CTX.supplyResult( s -> s.executeDataQuery("SELECT " + "Timestamp64('-144169-01-01T00:00:00Z') as t1," diff --git a/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java b/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java new file mode 100644 index 000000000..4f48787cb --- /dev/null +++ b/table/src/test/java/tech/ydb/table/query/ApacheArrowWriterTest.java @@ -0,0 +1,402 @@ +package tech.ydb.table.query; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.channels.Channels; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.util.UUID; + +import com.google.protobuf.ByteString; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ReadChannel; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.AfterClass; +import org.junit.Assert; +import 
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.function.ThrowingRunnable;

import tech.ydb.table.values.DecimalType;
import tech.ydb.table.values.DecimalValue;
import tech.ydb.table.values.ListType;
import tech.ydb.table.values.PrimitiveType;

/**
 * Unit tests for {@code ApacheArrowWriter}: schema construction, per-column
 * type validation of the {@code write*} methods, and serialization of batches
 * into Arrow IPC bytes exposed through {@code BulkUpsertArrowData}.
 *
 * @author Aleksandr Gorshenin
 */
public class ApacheArrowWriterTest {

    // Asserts that runnable throws IllegalArgumentException with exactly the given message.
    private void assertIllegalArgument(String message, ThrowingRunnable runnable) {
        IllegalArgumentException ex = Assert.assertThrows(IllegalArgumentException.class, runnable);
        Assert.assertEquals(message, ex.getMessage());
    }

    // Asserts that runnable throws IllegalStateException with exactly the given message.
    private void assertIllegalState(String message, ThrowingRunnable runnable) {
        IllegalStateException ex = Assert.assertThrows(IllegalStateException.class, runnable);
        Assert.assertEquals(message, ex.getMessage());
    }

    // Shared Arrow buffer allocator for all tests in this class.
    private static RootAllocator allocator;

    @BeforeClass
    public static void initAllocator() {
        allocator = new RootAllocator();
    }

    @AfterClass
    public static void cleanAllocator() {
        // Closing the root allocator lets Arrow detect any buffers the tests leaked.
        allocator.close();
    }

    /** Schema creation must reject YDB types that have no Arrow mapping. */
    @Test
    public void unsupportedTypesTest() {
        // tz date is not supported
        assertIllegalArgument("Type TzDate is not supported in ArrowWriter", () -> ApacheArrowWriter.newSchema()
                .addColumn("col", PrimitiveType.TzDate)
                .createWriter(allocator));

        // complex types are not supported (except Optional)
        // "Int32?" is the YDB rendering of Optional<Int32>; nullability must be
        // declared via addNullableColumn, not by passing an optional type.
        assertIllegalArgument("Type Int32? is not supported in ArrowWriter", () -> ApacheArrowWriter.newSchema()
                .addNullableColumn("col", PrimitiveType.Int32.makeOptional())
                .createWriter(allocator));

        // Non primitive types are not supported (except Decimal)
        // NOTE(review): expected message may have lost a generic suffix (e.g. "List<Int32>")
        // during extraction — confirm against ListType.toString().
        assertIllegalArgument("Type List is not supported in ArrowWriter", () -> ApacheArrowWriter.newSchema()
                .addColumn("col", ListType.of(PrimitiveType.Int32))
                .createWriter(allocator));
    }

    /** Writing to a column name absent from the schema must fail fast. */
    @Test
    public void invalidColumnTest() {
        try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
                .addColumn("col", PrimitiveType.Uuid)
                .createWriter(allocator)) {

            ApacheArrowWriter.Batch batch = writer.createNewBatch(0);
            ApacheArrowWriter.Row row = batch.writeNextRow();
            assertIllegalArgument("Column 'col2' not found", () -> row.writeUuid("col2", UUID.randomUUID()));
        }
    }

    /**
     * writeNull is accepted only for columns declared nullable — either via an
     * optional value type (col2) or via addNullableColumn (col3).
     */
    @Test
    public void nullableTypeTest() {
        try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
                .addColumn("col1", PrimitiveType.Uuid)
                .addColumn("col2", PrimitiveType.Uuid.makeOptional())
                .addNullableColumn("col3", PrimitiveType.Uuid)
                .createWriter(allocator)) {

            ApacheArrowWriter.Batch batch = writer.createNewBatch(0);
            ApacheArrowWriter.Row row = batch.writeNextRow();
            assertIllegalState("cannot call writeNull, actual type: Uuid", () -> row.writeNull("col1"));
            row.writeNull("col2"); // success
            row.writeNull("col3"); // success
        }
    }

    /**
     * Every write* accessor must reject a column whose declared type does not
     * match the accessor; the error message names both the call and the actual type.
     */
    @Test
    public void baseTypeValidationTest() throws IOException {
        try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
                .addColumn("col1", PrimitiveType.Uuid)
                .addColumn("col2", PrimitiveType.Int32)
                .createWriter(allocator)) {

            ApacheArrowWriter.Batch batch = writer.createNewBatch(0);
            ApacheArrowWriter.Row row = batch.writeNextRow();

            assertIllegalState("cannot call writeBool, actual type: Uuid", () -> row.writeBool("col1", true));

            assertIllegalState("cannot call writeInt8, actual type: Uuid", () -> row.writeInt8("col1", (byte) 0));
            assertIllegalState("cannot call writeInt16, actual type: Uuid", () -> row.writeInt16("col1", (short) 0));
            assertIllegalState("cannot call writeInt32, actual type: Uuid", () -> row.writeInt32("col1", 0));
            assertIllegalState("cannot call writeInt64, actual type: Uuid", () -> row.writeInt64("col1", 0));

            assertIllegalState("cannot call writeUint8, actual type: Uuid", () -> row.writeUint8("col1", 0));
            assertIllegalState("cannot call writeUint16, actual type: Uuid", () -> row.writeUint16("col1", 0));
            assertIllegalState("cannot call writeUint32, actual type: Uuid", () -> row.writeUint32("col1", 0));
            assertIllegalState("cannot call writeUint64, actual type: Uuid", () -> row.writeUint64("col1", 0));

            assertIllegalState("cannot call writeFloat, actual type: Uuid", () -> row.writeFloat("col1", 0));
            assertIllegalState("cannot call writeDouble, actual type: Uuid", () -> row.writeDouble("col1", 0));

            assertIllegalState("cannot call writeText, actual type: Uuid", () -> row.writeText("col1", ""));
            assertIllegalState("cannot call writeJson, actual type: Uuid", () -> row.writeJson("col1", ""));
            assertIllegalState("cannot call writeJsonDocument, actual type: Uuid",
                    () -> row.writeJsonDocument("col1", ""));

            assertIllegalState("cannot call writeBytes, actual type: Uuid",
                    () -> row.writeBytes("col1", new byte[0]));
            assertIllegalState("cannot call writeYson, actual type: Uuid",
                    () -> row.writeYson("col1", new byte[0]));

            assertIllegalState("cannot call writeDate, actual type: Uuid",
                    () -> row.writeDate("col1", LocalDate.ofEpochDay(0)));
            assertIllegalState("cannot call writeDatetime, actual type: Uuid",
                    () -> row.writeDatetime("col1", LocalDateTime.now()));
            assertIllegalState("cannot call writeTimestamp, actual type: Uuid",
                    () -> row.writeTimestamp("col1", Instant.ofEpochSecond(0)));
            assertIllegalState("cannot call writeInterval, actual type: Uuid",
                    () -> row.writeInterval("col1", Duration.ZERO));

            assertIllegalState("cannot call writeDate32, actual type: Uuid",
                    () -> row.writeDate32("col1", LocalDate.ofEpochDay(0)));
            assertIllegalState("cannot call writeDatetime64, actual type: Uuid",
                    () -> row.writeDatetime64("col1", LocalDateTime.now()));
            assertIllegalState("cannot call writeTimestamp64, actual type: Uuid",
                    () -> row.writeTimestamp64("col1", Instant.ofEpochSecond(0)));
            assertIllegalState("cannot call writeInterval64, actual type: Uuid",
                    () -> row.writeInterval64("col1", Duration.ZERO));

            // second column
            assertIllegalState("cannot call writeDecimal, actual type: Int32",
                    () -> row.writeDecimal("col2", DecimalType.getDefault().newValue(0)));
            assertIllegalState("cannot call writeUuid, actual type: Int32",
                    () -> row.writeUuid("col2", UUID.randomUUID()));
        }
    }

    /** Types backed by 1-byte Arrow vectors (Bool, Int8, Uint8): cross-type writes rejected, schema serialized. */
    @Test
    public void tinyIntVectorTest() throws IOException {
        try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
                .addColumn("c1", PrimitiveType.Bool)
                .addNullableColumn("c2", PrimitiveType.Int8)
                .addNullableColumn("c3", PrimitiveType.Uint8)
                .createWriter(allocator)) {

            ApacheArrowWriter.Batch batch = writer.createNewBatch(0);
            ApacheArrowWriter.Row row = batch.writeNextRow();

            assertIllegalState("cannot call writeInt8, actual type: Bool", () -> row.writeInt8("c1", (byte) 0));
            assertIllegalState("cannot call writeUint8, actual type: Int8", () -> row.writeUint8("c2", 0));
            assertIllegalState("cannot call writeBool, actual type: Uint8", () -> row.writeBool("c3", false));

            BulkUpsertArrowData data = batch.buildBatch();

            Schema schema = readApacheArrowSchema(data.getSchema());
            // NOTE(review): expected value looks truncated — Arrow Schema.toString()
            // normally renders as "Schema<c1: ..., ...>"; confirm the field list.
            Assert.assertEquals("Schema",
                    schema.toString());
        }
    }

    /** Types backed by 2-byte Arrow vectors (Date, Int16, Uint16). */
    @Test
    public void smallIntVectorTest() throws IOException {
        try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
                .addColumn("c1", PrimitiveType.Date)
                .addNullableColumn("c2", PrimitiveType.Int16)
                .addNullableColumn("c3", PrimitiveType.Uint16)
                .createWriter(allocator)) {

            ApacheArrowWriter.Batch batch = writer.createNewBatch(0);
            ApacheArrowWriter.Row row = batch.writeNextRow();

            assertIllegalState("cannot call writeUint16, actual type: Date", () -> row.writeUint16("c1", 0));
            assertIllegalState("cannot call writeDate, actual type: Int16",
                    () -> row.writeDate("c2", LocalDate.ofEpochDay(0)));
            assertIllegalState("cannot call writeInt16, actual type: Uint16", () -> row.writeInt16("c3", (short) 0));

            BulkUpsertArrowData data = batch.buildBatch();

            Schema schema = readApacheArrowSchema(data.getSchema());
            // NOTE(review): expected value looks truncated — see tinyIntVectorTest.
            Assert.assertEquals("Schema",
                    schema.toString());
        }
    }

    /** Types backed by 4-byte Arrow vectors (Int32, Uint32, Date32, Datetime). */
    @Test
    public void intVectorTest() throws IOException {
        try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
                .addColumn("c1", PrimitiveType.Int32)
                .addColumn("c2", PrimitiveType.Uint32)
                .addNullableColumn("c3", PrimitiveType.Date32)
                .addNullableColumn("c4", PrimitiveType.Datetime)
                .createWriter(allocator)) {

            ApacheArrowWriter.Batch batch = writer.createNewBatch(0);
            ApacheArrowWriter.Row row = batch.writeNextRow();

            assertIllegalState("cannot call writeUint32, actual type: Int32", () -> row.writeUint32("c1", 0));
            assertIllegalState("cannot call writeDate32, actual type: Uint32",
                    () -> row.writeDate32("c2", LocalDate.ofEpochDay(0)));
            assertIllegalState("cannot call writeDatetime, actual type: Date32",
                    () -> row.writeDatetime("c3", LocalDateTime.now()));
            assertIllegalState("cannot call writeInt32, actual type: Datetime", () -> row.writeInt32("c4", 0));

            BulkUpsertArrowData data = batch.buildBatch();

            Schema schema = readApacheArrowSchema(data.getSchema());
            // NOTE(review): expected value looks truncated — see tinyIntVectorTest.
            Assert.assertEquals("Schema", schema.toString());
        }
    }

    /** Types backed by 8-byte Arrow vectors (Int64, Uint64, timestamps, intervals). */
    @Test
    public void bigIntVectorTest() throws IOException {
        try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
                .addColumn("c1", PrimitiveType.Int64)
                .addColumn("c2", PrimitiveType.Uint64)
                .addNullableColumn("c3", PrimitiveType.Datetime64)
                .addNullableColumn("c4", PrimitiveType.Timestamp)
                .addNullableColumn("c5", PrimitiveType.Timestamp64)
                .addNullableColumn("c6", PrimitiveType.Interval)
                .addNullableColumn("c7", PrimitiveType.Interval64)
                .createWriter(allocator)) {

            ApacheArrowWriter.Batch batch = writer.createNewBatch(0);
            ApacheArrowWriter.Row row = batch.writeNextRow();

            assertIllegalState("cannot call writeUint64, actual type: Int64", () -> row.writeUint64("c1", 0));
            assertIllegalState("cannot call writeDatetime64, actual type: Uint64",
                    () -> row.writeDatetime64("c2", LocalDateTime.now()));
            assertIllegalState("cannot call writeTimestamp, actual type: Datetime64",
                    () -> row.writeTimestamp("c3", Instant.now()));
            assertIllegalState("cannot call writeTimestamp64, actual type: Timestamp",
                    () -> row.writeTimestamp64("c4", Instant.now()));
            assertIllegalState("cannot call writeInterval, actual type: Timestamp64",
                    () -> row.writeInterval("c5", Duration.ZERO));
            assertIllegalState("cannot call writeInterval64, actual type: Interval",
                    () -> row.writeInterval64("c6", Duration.ZERO));
            assertIllegalState("cannot call writeInt64, actual type: Interval64",
                    () -> row.writeInt64("c7", 0));

            BulkUpsertArrowData data = batch.buildBatch();

            Schema schema = readApacheArrowSchema(data.getSchema());
            // NOTE(review): expected value looks truncated — see tinyIntVectorTest.
            Assert.assertEquals("Schema",
                    schema.toString());
        }
    }

    /** Text-like types (Text, Json, JsonDocument) share the VarChar vector but stay distinct. */
    @Test
    public void varCharVectorTest() throws IOException {
        try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
                .addColumn("c1", PrimitiveType.Text)
                .addNullableColumn("c2", PrimitiveType.Json)
                .addColumn("c3", PrimitiveType.JsonDocument)
                .createWriter(allocator)) {

            ApacheArrowWriter.Batch batch = writer.createNewBatch(0);
            ApacheArrowWriter.Row row = batch.writeNextRow();

            assertIllegalState("cannot call writeJson, actual type: Text", () -> row.writeJson("c1", ""));
            assertIllegalState("cannot call writeJsonDocument, actual type: Json", () -> row.writeJsonDocument("c2", ""));
            assertIllegalState("cannot call writeText, actual type: JsonDocument", () -> row.writeText("c3", ""));

            BulkUpsertArrowData data = batch.buildBatch();

            Schema schema = readApacheArrowSchema(data.getSchema());
            // NOTE(review): expected value looks truncated — see tinyIntVectorTest.
            Assert.assertEquals("Schema", schema.toString());
        }
    }

    /** Binary types (Bytes, Yson) share the VarBinary vector but stay distinct. */
    @Test
    public void varBinaryVectorTest() throws IOException {
        try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
                .addColumn("c1", PrimitiveType.Bytes)
                .addNullableColumn("c2", PrimitiveType.Yson)
                .createWriter(allocator)) {

            ApacheArrowWriter.Batch batch = writer.createNewBatch(0);
            ApacheArrowWriter.Row row = batch.writeNextRow();

            assertIllegalState("cannot call writeYson, actual type: Bytes", () -> row.writeYson("c1", new byte[0]));
            assertIllegalState("cannot call writeBytes, actual type: Yson", () -> row.writeBytes("c2", new byte[0]));

            BulkUpsertArrowData data = batch.buildBatch();

            Schema schema = readApacheArrowSchema(data.getSchema());
            // NOTE(review): expected value looks truncated — see tinyIntVectorTest.
            Assert.assertEquals("Schema", schema.toString());
        }
    }

    /** Fixed-width binary types (Uuid, Decimal) stay distinct even with equal byte size. */
    @Test
    public void fixedSizeBinaryVectorTest() throws IOException {
        DecimalValue dv = DecimalType.getDefault().newValue(0);
        UUID uv = UUID.randomUUID();

        try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
                .addColumn("c1", PrimitiveType.Uuid)
                .addNullableColumn("c2", DecimalType.getDefault())
                .createWriter(allocator)) {

            ApacheArrowWriter.Batch batch = writer.createNewBatch(0);
            ApacheArrowWriter.Row row = batch.writeNextRow();

            assertIllegalState("cannot call writeDecimal, actual type: Uuid", () -> row.writeDecimal("c1", dv));
            assertIllegalState("cannot call writeUuid, actual type: Decimal(22, 9)", () -> row.writeUuid("c2", uv));

            BulkUpsertArrowData data = batch.buildBatch();

            Schema schema = readApacheArrowSchema(data.getSchema());
            // NOTE(review): expected value looks truncated — see tinyIntVectorTest.
            Assert.assertEquals("Schema", schema.toString());
        }
    }

    /**
     * Builds a two-row (pk: Int32, value: Text?) batch, including a non-ASCII
     * value to exercise UTF-8 encoding, and returns its serialized form.
     */
    private BulkUpsertArrowData createSimpleBatch() throws IOException {
        try (ApacheArrowWriter writer = ApacheArrowWriter.newSchema()
                .addColumn("pk", PrimitiveType.Int32)
                .addNullableColumn("value", PrimitiveType.Text)
                .createWriter(allocator)) {

            // NOTE(review): createNewBatch arg is presumably the expected row count — confirm.
            ApacheArrowWriter.Batch batch = writer.createNewBatch(2);
            ApacheArrowWriter.Row row1 = batch.writeNextRow();
            ApacheArrowWriter.Row row2 = batch.writeNextRow();

            row1.writeInt32("pk", 1);
            row1.writeText("value", "value-1");

            row2.writeInt32("pk", 2);
            row2.writeText("value", "значение-2");

            return batch.buildBatch();
        }
    }

    /**
     * Round-trip: the serialized batch must be readable back through the stock
     * Arrow IPC machinery (MessageSerializer + VectorLoader) with intact values.
     */
    @Test
    public void readArrayBatchTest() throws IOException {
        BulkUpsertArrowData data = createSimpleBatch();

        Schema schema = readApacheArrowSchema(data.getSchema());
        try (VectorSchemaRoot vector = VectorSchemaRoot.create(schema, allocator)) {
            // readApacheArrowBatch
            try (InputStream is = data.getData().newInput()) {
                try (ReadChannel channel = new ReadChannel(Channels.newChannel(is))) {
                    try (ArrowRecordBatch batch = MessageSerializer.deserializeRecordBatch(channel, allocator)) {
                        VectorLoader loader = new VectorLoader(vector);
                        loader.load(batch);
                    }
                }
            }

            IntVector pk = (IntVector) vector.getVector("pk");
            VarCharVector value = (VarCharVector) vector.getVector("value");

            Assert.assertEquals(2, vector.getRowCount());
            Assert.assertEquals(1, pk.get(0));
            Assert.assertEquals(2, pk.get(1));
            Assert.assertEquals("value-1", new String(value.get(0), StandardCharsets.UTF_8));
            Assert.assertEquals("значение-2", new String(value.get(1), StandardCharsets.UTF_8));
        }
    }

    // Deserializes an Arrow IPC schema message from the given bytes;
    // wraps IOException as unchecked since test schemas are in-memory.
    private static Schema readApacheArrowSchema(ByteString bytes) {
        try (InputStream is = bytes.newInput()) {
            try (ReadChannel channel = new ReadChannel(Channels.newChannel(is))) {
                return MessageSerializer.deserializeSchema(channel);
            }
        } catch (IOException ex) {
            throw new RuntimeException(ex);
        }
    }
}