diff --git a/src/DataFrame/IO/Parquet/Encoding.hs b/src/DataFrame/IO/Parquet/Encoding.hs index 16c24e81..5d0afeb7 100644 --- a/src/DataFrame/IO/Parquet/Encoding.hs +++ b/src/DataFrame/IO/Parquet/Encoding.hs @@ -30,19 +30,24 @@ unpackBitPacked bw count bs BS.concatMap (\b -> BS.map (\i -> (b `shiftR` fromIntegral i) .&. 1) (BS.pack [0 .. 7])) chunk - toN = fst . BS.foldl' (\(a, i) b -> (a .|. (b `shiftL` i), i + 1)) (0, 0) + toN :: BS.ByteString -> Word32 + toN = + fst + . BS.foldl' + (\(a, i) b -> (a .|. (fromIntegral b `shiftL` i), i + 1)) + (0 :: Word32, 0 :: Int) - extractValues :: Int -> BS.ByteString -> BS.ByteString + extractValues :: Int -> BS.ByteString -> [Word32] extractValues n bitsLeft - | BS.null bitsLeft = BS.empty - | n <= 0 = BS.empty - | BS.length bitsLeft < bw = BS.empty + | BS.null bitsLeft = [] + | n <= 0 = [] + | BS.length bitsLeft < bw = [] | otherwise = let (this, bitsLeft') = BS.splitAt bw bitsLeft - in BS.cons (toN this) (extractValues (n - 1) bitsLeft') + in toN this : extractValues (n - 1) bitsLeft' vals = extractValues count bits - in (map fromIntegral (BS.unpack vals), rest) + in (vals, rest) decodeRLEBitPackedHybrid :: Int -> Int -> BS.ByteString -> ([Word32], BS.ByteString) diff --git a/src/DataFrame/IO/Parquet/Thrift.hs b/src/DataFrame/IO/Parquet/Thrift.hs index 7e91ce23..887b4fe7 100644 --- a/src/DataFrame/IO/Parquet/Thrift.hs +++ b/src/DataFrame/IO/Parquet/Thrift.hs @@ -811,21 +811,21 @@ readAesGcmCtrV1 v@(AesGcmCtrV1 aadPrefix aadFileUnique supplyAadPrefix) buf pos Just (elemType, identifier) -> case identifier of 1 -> do aadPrefix <- readByteString buf pos - readAesGcmCtrV1 (v{aadPrefix = aadPrefix}) buf pos lastFieldId + readAesGcmCtrV1 (v{aadPrefix = aadPrefix}) buf pos identifier 2 -> do aadFileUnique <- readByteString buf pos readAesGcmCtrV1 (v{aadFileUnique = aadFileUnique}) buf pos - lastFieldId + identifier 3 -> do supplyAadPrefix <- readAndAdvance pos buf readAesGcmCtrV1 (v{supplyAadPrefix = supplyAadPrefix == compactBooleanTrue}) buf pos - lastFieldId + identifier _ -> return ENCRYPTION_ALGORITHM_UNKNOWN readAesGcmCtrV1 _ _ _ _ = error "readAesGcmCtrV1 called with non AesGcmCtrV1" @@ -843,17 +843,17 @@ readAesGcmV1 v@(AesGcmV1 aadPrefix aadFileUnique supplyAadPrefix) buf pos lastFi Just (elemType, identifier) -> case identifier of 1 -> do aadPrefix <- readByteString buf pos - readAesGcmV1 (v{aadPrefix = aadPrefix}) buf pos lastFieldId + readAesGcmV1 (v{aadPrefix = aadPrefix}) buf pos identifier 2 -> do aadFileUnique <- readByteString buf pos - readAesGcmV1 (v{aadFileUnique = aadFileUnique}) buf pos lastFieldId + readAesGcmV1 (v{aadFileUnique = aadFileUnique}) buf pos identifier 3 -> do supplyAadPrefix <- readAndAdvance pos buf readAesGcmV1 (v{supplyAadPrefix = supplyAadPrefix == compactBooleanTrue}) buf pos - lastFieldId + identifier _ -> return ENCRYPTION_ALGORITHM_UNKNOWN readAesGcmV1 _ _ _ _ = error "readAesGcmV1 called with non AesGcmV1" @@ -1120,10 +1120,10 @@ readDecimalType precision scale buf pos lastFieldId = do Just (elemType, identifier) -> case identifier of 1 -> do scale' <- readInt32FromBuffer buf pos - readDecimalType precision scale' buf pos lastFieldId + readDecimalType precision scale' buf pos identifier 2 -> do precision' <- readInt32FromBuffer buf pos - readDecimalType precision' scale buf pos lastFieldId + readDecimalType precision' scale buf pos identifier _ -> error $ "UNKNOWN field ID for DecimalType" ++ show identifier readTimeType :: @@ -1136,15 +1136,14 @@ readTimeType :: readTimeType isAdjustedToUTC unit buf pos lastFieldId = do fieldContents <- readField buf pos lastFieldId case fieldContents of - Nothing -> return (TimeType isAdjustedToUTC unit) + Nothing -> return (TimeType{isAdjustedToUTC = isAdjustedToUTC, unit = unit}) Just (elemType, identifier) -> case identifier of 1 -> do - -- TODO: Check for empty isAdjustedToUTC' <- (== compactBooleanTrue) <$> readAndAdvance pos buf - readTimeType isAdjustedToUTC' unit buf pos lastFieldId + readTimeType isAdjustedToUTC' unit buf pos identifier 2 -> do unit' <- readUnit TIME_UNIT_UNKNOWN buf pos 0 - readTimeType isAdjustedToUTC unit' buf pos lastFieldId + readTimeType isAdjustedToUTC unit' buf pos identifier _ -> error $ "UNKNOWN field ID for TimeType" ++ show identifier readTimestampType :: @@ -1160,13 +1159,12 @@ readTimestampType isAdjustedToUTC unit buf pos lastFieldId = do Nothing -> return (TimestampType isAdjustedToUTC unit) Just (elemType, identifier) -> case identifier of 1 -> do - -- TODO: Check for empty isAdjustedToUTC' <- (== compactBooleanTrue) <$> readNoAdvance pos buf - readTimestampType False unit buf pos lastFieldId + readTimestampType False unit buf pos identifier 2 -> do - _ <- readField buf pos identifier + _ <- readField buf pos 0 unit' <- readUnit TIME_UNIT_UNKNOWN buf pos 0 - readTimestampType isAdjustedToUTC unit' buf pos lastFieldId + readTimestampType isAdjustedToUTC unit' buf pos identifier _ -> error $ "UNKNOWN field ID for TimestampType" ++ show identifier readUnit :: TimeUnit -> BS.ByteString -> IORef Int -> Int16 -> IO TimeUnit diff --git a/tests/Parquet.hs b/tests/Parquet.hs index 769e583c..f985c3c4 100644 --- a/tests/Parquet.hs +++ b/tests/Parquet.hs @@ -61,6 +61,83 @@ allTypesPlain = (unsafePerformIO (D.readParquet "./tests/data/alltypes_plain.parquet")) ) +allTypesTinyPagesDimensions :: Test +allTypesTinyPagesDimensions = + TestCase + ( assertEqual + "allTypesTinyPages last few" + (7300, 13) + ( unsafePerformIO + (fmap D.dimensions (D.readParquet "./tests/data/alltypes_tiny_pages.parquet")) + ) + ) + +tinyPagesLast10 :: D.DataFrame +tinyPagesLast10 = + D.fromNamedColumns + [ ("id", D.fromList @Int32 (reverse [6174 .. 6183])) + , ("bool_col", D.fromList @Bool (Prelude.take 10 (cycle [False, True]))) + , ("tinyint_col", D.fromList @Int32 [3, 2, 1, 0, 9, 8, 7, 6, 5, 4]) + , ("smallint_col", D.fromList @Int32 [3, 2, 1, 0, 9, 8, 7, 6, 5, 4]) + , ("int_col", D.fromList @Int32 [3, 2, 1, 0, 9, 8, 7, 6, 5, 4]) + , ("bigint_col", D.fromList @Int64 [30, 20, 10, 0, 90, 80, 70, 60, 50, 40]) + , + ( "float_col" + , D.fromList @Float [3.3, 2.2, 1.1, 0, 9.9, 8.8, 7.7, 6.6, 5.5, 4.4] + ) + , + ( "date_string_col" + , D.fromList @Text + [ "09/11/10" + , "09/11/10" + , "09/11/10" + , "09/11/10" + , "09/10/10" + , "09/10/10" + , "09/10/10" + , "09/10/10" + , "09/10/10" + , "09/10/10" + ] + ) + , + ( "string_col" + , D.fromList @Text ["3", "2", "1", "0", "9", "8", "7", "6", "5", "4"] + ) + , + ( "timestamp_col" + , D.fromList @UTCTime + [ UTCTime (fromGregorian 2010 9 10) (secondsToDiffTime 85384) + , UTCTime (fromGregorian 2010 9 10) (secondsToDiffTime 85324) + , UTCTime (fromGregorian 2010 9 10) (secondsToDiffTime 85264) + , UTCTime (fromGregorian 2010 9 10) (secondsToDiffTime 85204) + , UTCTime (fromGregorian 2010 9 9) (secondsToDiffTime 85144) + , UTCTime (fromGregorian 2010 9 9) (secondsToDiffTime 85084) + , UTCTime (fromGregorian 2010 9 9) (secondsToDiffTime 85024) + , UTCTime (fromGregorian 2010 9 9) (secondsToDiffTime 84964) + , UTCTime (fromGregorian 2010 9 9) (secondsToDiffTime 84904) + , UTCTime (fromGregorian 2010 9 9) (secondsToDiffTime 84844) + ] + ) + , ("year", D.fromList @Int32 (replicate 10 2010)) + , ("month", D.fromList @Int32 (replicate 10 9)) + ] + +allTypesTinyPagesLastFew :: Test +allTypesTinyPagesLastFew = + TestCase + ( assertEqual + "allTypesTinyPages dimensions" + tinyPagesLast10 + ( unsafePerformIO + -- Excluding doubles because they are weird to compare. + ( fmap + (D.takeLast 10 . D.exclude ["double_col"]) + (D.readParquet "./tests/data/alltypes_tiny_pages.parquet") + ) + ) + ) + allTypesPlainSnappy :: Test allTypesPlainSnappy = TestCase @@ -537,7 +614,12 @@ mtCars = (unsafePerformIO (D.readParquet "./tests/data/mtcars.parquet")) ) --- Uncomment to run parquet tests. --- Currently commented because they don't run with github CI tests :: [Test] -tests = [allTypesPlain, allTypesPlainSnappy, allTypesDictionary, mtCars] +tests = + [ allTypesPlain + , allTypesPlainSnappy + , allTypesDictionary + , mtCars + , allTypesTinyPagesLastFew + , allTypesTinyPagesDimensions + ]