Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions src/DataFrame/IO/Parquet/Encoding.hs
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,24 @@ unpackBitPacked bw count bs
BS.concatMap
(\b -> BS.map (\i -> (b `shiftR` fromIntegral i) .&. 1) (BS.pack [0 .. 7]))
chunk
toN = fst . BS.foldl' (\(a, i) b -> (a .|. (b `shiftL` i), i + 1)) (0, 0)
toN :: BS.ByteString -> Word32
toN =
fst
. BS.foldl'
(\(a, i) b -> (a .|. (fromIntegral b `shiftL` i), i + 1))
(0 :: Word32, 0 :: Int)

extractValues :: Int -> BS.ByteString -> BS.ByteString
extractValues :: Int -> BS.ByteString -> [Word32]
extractValues n bitsLeft
| BS.null bitsLeft = BS.empty
| n <= 0 = BS.empty
| BS.length bitsLeft < bw = BS.empty
| BS.null bitsLeft = []
| n <= 0 = []
| BS.length bitsLeft < bw = []
| otherwise =
let (this, bitsLeft') = BS.splitAt bw bitsLeft
in BS.cons (toN this) (extractValues (n - 1) bitsLeft')
in toN this : extractValues (n - 1) bitsLeft'

vals = extractValues count bits
in (map fromIntegral (BS.unpack vals), rest)
in (vals, rest)

decodeRLEBitPackedHybrid ::
Int -> Int -> BS.ByteString -> ([Word32], BS.ByteString)
Expand Down
30 changes: 14 additions & 16 deletions src/DataFrame/IO/Parquet/Thrift.hs
Original file line number Diff line number Diff line change
Expand Up @@ -811,21 +811,21 @@ readAesGcmCtrV1 v@(AesGcmCtrV1 aadPrefix aadFileUnique supplyAadPrefix) buf pos
Just (elemType, identifier) -> case identifier of
1 -> do
aadPrefix <- readByteString buf pos
readAesGcmCtrV1 (v{aadPrefix = aadPrefix}) buf pos lastFieldId
readAesGcmCtrV1 (v{aadPrefix = aadPrefix}) buf pos identifier
2 -> do
aadFileUnique <- readByteString buf pos
readAesGcmCtrV1
(v{aadFileUnique = aadFileUnique})
buf
pos
lastFieldId
identifier
3 -> do
supplyAadPrefix <- readAndAdvance pos buf
readAesGcmCtrV1
(v{supplyAadPrefix = supplyAadPrefix == compactBooleanTrue})
buf
pos
lastFieldId
identifier
_ -> return ENCRYPTION_ALGORITHM_UNKNOWN
readAesGcmCtrV1 _ _ _ _ =
error "readAesGcmCtrV1 called with non AesGcmCtrV1"
Expand All @@ -843,17 +843,17 @@ readAesGcmV1 v@(AesGcmV1 aadPrefix aadFileUnique supplyAadPrefix) buf pos lastFi
Just (elemType, identifier) -> case identifier of
1 -> do
aadPrefix <- readByteString buf pos
readAesGcmV1 (v{aadPrefix = aadPrefix}) buf pos lastFieldId
readAesGcmV1 (v{aadPrefix = aadPrefix}) buf pos identifier
2 -> do
aadFileUnique <- readByteString buf pos
readAesGcmV1 (v{aadFileUnique = aadFileUnique}) buf pos lastFieldId
readAesGcmV1 (v{aadFileUnique = aadFileUnique}) buf pos identifier
3 -> do
supplyAadPrefix <- readAndAdvance pos buf
readAesGcmV1
(v{supplyAadPrefix = supplyAadPrefix == compactBooleanTrue})
buf
pos
lastFieldId
identifier
_ -> return ENCRYPTION_ALGORITHM_UNKNOWN
readAesGcmV1 _ _ _ _ =
error "readAesGcmV1 called with non AesGcmV1"
Expand Down Expand Up @@ -1120,10 +1120,10 @@ readDecimalType precision scale buf pos lastFieldId = do
Just (elemType, identifier) -> case identifier of
1 -> do
scale' <- readInt32FromBuffer buf pos
readDecimalType precision scale' buf pos lastFieldId
readDecimalType precision scale' buf pos identifier
2 -> do
precision' <- readInt32FromBuffer buf pos
readDecimalType precision' scale buf pos lastFieldId
readDecimalType precision' scale buf pos identifier
_ -> error $ "UNKNOWN field ID for DecimalType" ++ show identifier

readTimeType ::
Expand All @@ -1136,15 +1136,14 @@ readTimeType ::
readTimeType isAdjustedToUTC unit buf pos lastFieldId = do
fieldContents <- readField buf pos lastFieldId
case fieldContents of
Nothing -> return (TimeType isAdjustedToUTC unit)
Nothing -> return (TimeType{isAdjustedToUTC = isAdjustedToUTC, unit = unit})
Just (elemType, identifier) -> case identifier of
1 -> do
-- TODO: Check for empty
isAdjustedToUTC' <- (== compactBooleanTrue) <$> readAndAdvance pos buf
readTimeType isAdjustedToUTC' unit buf pos lastFieldId
readTimeType isAdjustedToUTC' unit buf pos identifier
2 -> do
unit' <- readUnit TIME_UNIT_UNKNOWN buf pos 0
readTimeType isAdjustedToUTC unit' buf pos lastFieldId
readTimeType isAdjustedToUTC unit' buf pos identifier
_ -> error $ "UNKNOWN field ID for TimeType" ++ show identifier

readTimestampType ::
Expand All @@ -1160,13 +1159,12 @@ readTimestampType isAdjustedToUTC unit buf pos lastFieldId = do
Nothing -> return (TimestampType isAdjustedToUTC unit)
Just (elemType, identifier) -> case identifier of
1 -> do
-- TODO: Check for empty
isAdjustedToUTC' <- (== compactBooleanTrue) <$> readNoAdvance pos buf
readTimestampType False unit buf pos lastFieldId
readTimestampType False unit buf pos identifier
2 -> do
_ <- readField buf pos identifier
_ <- readField buf pos 0
unit' <- readUnit TIME_UNIT_UNKNOWN buf pos 0
readTimestampType isAdjustedToUTC unit' buf pos lastFieldId
readTimestampType isAdjustedToUTC unit' buf pos identifier
_ -> error $ "UNKNOWN field ID for TimestampType" ++ show identifier

readUnit :: TimeUnit -> BS.ByteString -> IORef Int -> Int16 -> IO TimeUnit
Expand Down
88 changes: 85 additions & 3 deletions tests/Parquet.hs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,83 @@ allTypesPlain =
(unsafePerformIO (D.readParquet "./tests/data/alltypes_plain.parquet"))
)

allTypesTinyPagesDimensions :: Test
allTypesTinyPagesDimensions =
TestCase
( assertEqual
"allTypesTinyPages last few"
(7300, 13)
( unsafePerformIO
(fmap D.dimensions (D.readParquet "./tests/data/alltypes_tiny_pages.parquet"))
)
)

tinyPagesLast10 :: D.DataFrame
tinyPagesLast10 =
D.fromNamedColumns
[ ("id", D.fromList @Int32 (reverse [6174 .. 6183]))
, ("bool_col", D.fromList @Bool (Prelude.take 10 (cycle [False, True])))
, ("tinyint_col", D.fromList @Int32 [3, 2, 1, 0, 9, 8, 7, 6, 5, 4])
, ("smallint_col", D.fromList @Int32 [3, 2, 1, 0, 9, 8, 7, 6, 5, 4])
, ("int_col", D.fromList @Int32 [3, 2, 1, 0, 9, 8, 7, 6, 5, 4])
, ("bigint_col", D.fromList @Int64 [30, 20, 10, 0, 90, 80, 70, 60, 50, 40])
,
( "float_col"
, D.fromList @Float [3.3, 2.2, 1.1, 0, 9.9, 8.8, 7.7, 6.6, 5.5, 4.4]
)
,
( "date_string_col"
, D.fromList @Text
[ "09/11/10"
, "09/11/10"
, "09/11/10"
, "09/11/10"
, "09/10/10"
, "09/10/10"
, "09/10/10"
, "09/10/10"
, "09/10/10"
, "09/10/10"
]
)
,
( "string_col"
, D.fromList @Text ["3", "2", "1", "0", "9", "8", "7", "6", "5", "4"]
)
,
( "timestamp_col"
, D.fromList @UTCTime
[ UTCTime (fromGregorian 2010 9 10) (secondsToDiffTime 85384)
, UTCTime (fromGregorian 2010 9 10) (secondsToDiffTime 85324)
, UTCTime (fromGregorian 2010 9 10) (secondsToDiffTime 85264)
, UTCTime (fromGregorian 2010 9 10) (secondsToDiffTime 85204)
, UTCTime (fromGregorian 2010 9 9) (secondsToDiffTime 85144)
, UTCTime (fromGregorian 2010 9 9) (secondsToDiffTime 85084)
, UTCTime (fromGregorian 2010 9 9) (secondsToDiffTime 85024)
, UTCTime (fromGregorian 2010 9 9) (secondsToDiffTime 84964)
, UTCTime (fromGregorian 2010 9 9) (secondsToDiffTime 84904)
, UTCTime (fromGregorian 2010 9 9) (secondsToDiffTime 84844)
]
)
, ("year", D.fromList @Int32 (replicate 10 2010))
, ("month", D.fromList @Int32 (replicate 10 9))
]

allTypesTinyPagesLastFew :: Test
allTypesTinyPagesLastFew =
TestCase
( assertEqual
"allTypesTinyPages dimensions"
tinyPagesLast10
( unsafePerformIO
-- Excluding doubles because they are weird to compare.
( fmap
(D.takeLast 10 . D.exclude ["double_col"])
(D.readParquet "./tests/data/alltypes_tiny_pages.parquet")
)
)
)

allTypesPlainSnappy :: Test
allTypesPlainSnappy =
TestCase
Expand Down Expand Up @@ -537,7 +614,12 @@ mtCars =
(unsafePerformIO (D.readParquet "./tests/data/mtcars.parquet"))
)

-- Uncomment to run parquet tests.
-- Currently commented because they don't run with github CI
tests :: [Test]
tests = [allTypesPlain, allTypesPlainSnappy, allTypesDictionary, mtCars]
tests =
[ allTypesPlain
, allTypesPlainSnappy
, allTypesDictionary
, mtCars
, allTypesTinyPagesLastFew
, allTypesTinyPagesDimensions
]