diff --git a/.gitignore b/.gitignore
index 26ae9a84..b5e3cf28 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,7 @@
 *.log
 *.pyc
 build/*.jar
+.coverage
 docs/_site
 docs/api
 
@@ -18,6 +19,9 @@
 src_managed/
 project/boot/
 project/plugins/project/
+# spark
+metastore_db
+
 # intellij
 .idea/
 
diff --git a/python/sparkdl/image/imageIO.py b/python/sparkdl/image/imageIO.py
index 03c101cc..5cb0aba6 100644
--- a/python/sparkdl/image/imageIO.py
+++ b/python/sparkdl/image/imageIO.py
@@ -34,6 +34,10 @@
                          StructField("nChannels", IntegerType(), False),
                          StructField("data", BinaryType(), False)])
 
+gifSchema = StructType([StructField("filePath", StringType(), False),
+                        StructField("frameNum", IntegerType(), True),
+                        StructField("gifFrame", imageSchema, True)])
+
 # ImageType class for holding metadata about images stored in DataFrames.
 # fields:
 
@@ -199,23 +203,79 @@ def _decodeImage(imageData):
     image = imageArrayToStruct(imgArray, mode.sparkMode)
     return image
 
+
+def _decodeGif(gifData):
+    """
+    Decode compressed GIF data into a sequence of images.
+
+    :param gifData: (bytes, bytearray) compressed GIF data in a PIL-compatible format.
+    :return: list of (frame index, image struct) tuples, one tuple per frame,
+        or [(None, None)] if the data cannot be decoded as a GIF.
+    """
+    try:
+        img = Image.open(BytesIO(gifData))
+    except IOError:
+        return [(None, None)]
+
+    if img.format.lower() == "gif":
+        mode = pilModeLookup["RGB"]
+    else:
+        warn("Image file does not appear to be a GIF")
+        return [(None, None)]
+
+    frames = []
+    i = 0
+    mypalette = img.getpalette()
+    try:
+        while True:
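+            # Later frames often omit a local palette; reapplying the GIF's
+            # global palette keeps each P-mode frame convertible to RGB.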
+            if not img.getpalette() and mypalette:
+                img.putpalette(mypalette)
+            newImg = Image.new("RGB", img.size)
+            newImg.paste(img)
+
+            newImgArray = np.asarray(newImg)
+            newImage = imageArrayToStruct(newImgArray, mode.sparkMode)
+            frames.append((i, newImage))
+
+            i += 1
+            img.seek(img.tell() + 1)
+    except EOFError:
+        # end of sequence
+        pass
+
+    return frames
+
 # Creating a UDF on import can cause SparkContext issues sometimes.
 # decodeImage = udf(_decodeImage, imageSchema)
 
 
+def filesToRDD(sc, path, numPartitions=None):
+    """
+    Read files from a directory to an RDD.
+
+    :param sc: SparkContext.
+    :param path: str, path to files.
+    :param numPartitions: int, number of partitions to use for reading files.
+    :return: RDD of (filePath: str, fileData: bytearray) tuples.
+    """
+    numPartitions = numPartitions or sc.defaultParallelism
+    rdd = sc.binaryFiles(path, minPartitions=numPartitions).repartition(numPartitions)
+    return rdd.map(lambda x: (x[0], bytearray(x[1])))
+
+
 def filesToDF(sc, path, numPartitions=None):
     """
     Read files from a directory to a DataFrame.
 
     :param sc: SparkContext.
     :param path: str, path to files.
-    :param numPartition: int, number or partitions to use for reading files.
+    :param numPartitions: int, number of partitions to use for reading files.
     :return: DataFrame, with columns: (filePath: str, fileData: BinaryType)
     """
-    numPartitions = numPartitions or sc.defaultParallelism
     schema = StructType([StructField("filePath", StringType(), False),
                          StructField("fileData", BinaryType(), False)])
-    rdd = sc.binaryFiles(path, minPartitions=numPartitions).repartition(numPartitions)
-    rdd = rdd.map(lambda x: (x[0], bytearray(x[1])))
+    rdd = filesToRDD(sc, path, numPartitions)
     return rdd.toDF(schema)
 
@@ -235,3 +293,22 @@ def _readImages(imageDirectory, numPartition, sc):
     decodeImage = udf(_decodeImage, imageSchema)
     imageData = filesToDF(sc, imageDirectory, numPartitions=numPartition)
     return imageData.select("filePath", decodeImage("fileData").alias("image"))
+
+
+def readGifs(gifDirectory, numPartition=None):
+    """
+    Read a directory of GIFs (or a single GIF) into a DataFrame.
+
+    :param gifDirectory: str, file path.
+    :param numPartition: int, number of partitions to use for reading files.
+    :return: DataFrame, with columns: (filePath: str, frameNum: int, gifFrame: imageSchema).
+    """
+    return _readGifs(gifDirectory, numPartition, SparkContext.getOrCreate())
+
+
+def _readGifs(gifDirectory, numPartition, sc):
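+    # One output row per decoded frame; files that fail to decode surface
+    # as a single (filePath, None, None) row instead of being dropped.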
+    gifsRDD = filesToRDD(sc, gifDirectory, numPartitions=numPartition)
+    framesRDD = gifsRDD.flatMap(lambda x: [(x[0], i, frame) for (i, frame) in _decodeGif(x[1])])
+    return framesRDD.toDF(gifSchema)
diff --git a/python/tests/image/test_imageIO.py b/python/tests/image/test_imageIO.py
index 0b6f6e61..9f47b72b 100644
--- a/python/tests/image/test_imageIO.py
+++ b/python/tests/image/test_imageIO.py
@@ -26,7 +26,7 @@
 from sparkdl.image import imageIO
 from ..tests import SparkDLTestCase
 
-# Create dome fake image data to work with
+# Create some fake image data to work with
 def create_image_data():
     # Random image-like data
     array = np.random.randint(0, 256, (10, 11, 3), 'uint8')
@@ -173,4 +173,98 @@ def test_filesTODF(self):
         self.assertEqual(type(first.fileData), bytearray)
 
 
+# Create some fake GIF data to work with
+def create_gif_data():
+    # Random GIF-like data
+    arrays2D = [np.random.randint(0, 256, (10, 11), 'uint8') for _ in range(3)]
+    arrays3D = [np.dstack((a, a, a)) for a in arrays2D]
+    # Create frames in P mode because Pillow always reads GIFs as P or L images
+    frames = [PIL.Image.fromarray(a, mode='P') for a in arrays2D]
+
+    # Compress as GIF
+    gifFile = BytesIO()
+    frames[0].save(gifFile, 'gif', save_all=True, append_images=frames[1:], optimize=False)
+    gifFile.seek(0)
+
+    # Get GIF data as stream
+    gifData = gifFile.read()
+    return arrays3D, gifData
+
+gifArray, gifData = create_gif_data()
+frameArray = gifArray[0]
+
+
+class BinaryGifFilesMock(object):
+
+    defaultParallelism = 4
+
+    def __init__(self, sc):
+        self.sc = sc
+
+    def binaryFiles(self, path, minPartitions=None):
+        gifsData = [["file/path", gifData],
+                    ["another/file/path", gifData],
+                    ["bad/gif", b"badGifData"]
+                    ]
+        rdd = self.sc.parallelize(gifsData)
+        if minPartitions is not None:
+            rdd = rdd.repartition(minPartitions)
+        return rdd
+
+
+class TestReadGifs(SparkDLTestCase):
+    @classmethod
+    def setUpClass(cls):
+        super(TestReadGifs, cls).setUpClass()
+        cls.binaryFilesMock = BinaryGifFilesMock(cls.sc)
+
+    @classmethod
+    def tearDownClass(cls):
+        super(TestReadGifs, cls).tearDownClass()
+        cls.binaryFilesMock = None
+
+    def test_decodeGif(self):
+        badFrames = imageIO._decodeGif(b"xxx")
+        self.assertEqual(badFrames, [(None, None)])
+        gifFrames = imageIO._decodeGif(gifData)
+        self.assertIsNotNone(gifFrames)
+        self.assertEqual(len(gifFrames), 3)
+        self.assertEqual(len(gifFrames[0][1]), len(imageIO.imageSchema.names))
+        for n in imageIO.imageSchema.names:
+            gifFrames[0][1][n]
+
+    def test_gif_round_trip(self):
+        # Test round trip: array -> GIF frame -> sparkImg -> array
+        binarySchema = StructType([StructField("data", BinaryType(), False)])
+        binaryRDD = self.session.sparkContext.parallelize([bytearray(gifData)])
+
+        # Convert to GIF frames
+        rdd = binaryRDD.flatMap(lambda x: [f[1] for f in imageIO._decodeGif(x)])
+        framesDF = rdd.toDF(imageIO.imageSchema)
+        row = framesDF.first()
+
+        testArray = imageIO.imageStructToArray(row)
+        self.assertEqual(testArray.shape, frameArray.shape)
+        self.assertEqual(testArray.dtype, frameArray.dtype)
+        self.assertTrue(np.all(frameArray == testArray))
+
+    def test_readGifs(self):
+        # Test that reading GIFs produces the expected columns and rows
+        gifDF = imageIO._readGifs("some/path", 2, self.binaryFilesMock)
+        self.assertTrue("filePath" in gifDF.schema.names)
+        self.assertTrue("frameNum" in gifDF.schema.names)
+        self.assertTrue("gifFrame" in gifDF.schema.names)
+
+        # The DF should have 6 images (2 GIFs, 3 frames each) and 1 null.
+        self.assertEqual(gifDF.count(), 7)
+        validGifs = gifDF.filter(col("gifFrame").isNotNull())
+        self.assertEqual(validGifs.count(), 6)
+
+        frame = validGifs.first().gifFrame
+        self.assertEqual(frame.height, frameArray.shape[0])
+        self.assertEqual(frame.width, frameArray.shape[1])
+        self.assertEqual(imageIO.imageType(frame).nChannels, frameArray.shape[2])
+        self.assertEqual(frame.data, frameArray.tobytes())
+
+
 # TODO: make unit tests for arrayToImageRow on arrays of varying shapes, channels, dtypes.
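
A minimal usage sketch for the new reader (assuming the patch above is applied and a SparkSession is active; the directory path is illustrative):

```python
from pyspark.sql.functions import col

from sparkdl.image import imageIO

# Each GIF under the directory becomes one row per frame, with columns
# (filePath: str, frameNum: int, gifFrame: imageSchema).
framesDF = imageIO.readGifs("data/gifs/", numPartition=8)

# Files that could not be decoded appear as rows with a null gifFrame,
# so they can be filtered out (or inspected) explicitly.
validFrames = framesDF.filter(col("gifFrame").isNotNull())
validFrames.select("filePath", "frameNum").show()
```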