Skip to content

Commit

Permalink
Add keras model for image classification example and api change (#4098)
Browse files Browse the repository at this point in the history
  • Loading branch information
dding3 authored Feb 24, 2022
1 parent 89a39fa commit fc4b6bd
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright 2016 The BigDL Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.intel.analytics.bigdl.dllib.example.keras

import com.intel.analytics.bigdl.dllib.NNContext
import com.intel.analytics.bigdl.dllib.feature.image.ImageChannelNormalize
import com.intel.analytics.bigdl.dllib.nnframes.NNImageReader
import com.intel.analytics.bigdl.dllib.keras.layers._
import com.intel.analytics.bigdl.dllib.utils.Shape
import com.intel.analytics.bigdl.dllib.keras.Sequential
import com.intel.analytics.bigdl.dllib.keras.objectives.BinaryCrossEntropy
import com.intel.analytics.bigdl.dllib.optim._
import com.intel.analytics.bigdl.dllib.models.lenet.Utils._
import com.intel.analytics.bigdl.numeric.NumericFloat
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._

object ImageClassification {
def buildMode(inputShape: Shape): Sequential[Float] = {
import com.intel.analytics.bigdl.numeric.NumericFloat
val model = Sequential()
model.add(Conv2D(32, 3, 3, inputShape = inputShape))
model.add(Activation("relu"))
model.add(MaxPooling2D(poolSize = (2, 2)))

model.add(Conv2D(32, 3, 3))
model.add(Activation("relu"))
model.add(MaxPooling2D(poolSize = (2, 2)))

model.add(Conv2D(64, 3, 3))
model.add(Activation("relu"))
model.add(MaxPooling2D(poolSize = (2, 2)))

model.add(Flatten())
model.add(Dense(64))
model.add(Activation("relu"))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation("sigmoid"))

return model
}

def main(args: Array[String]): Unit = {
trainParser.parse(args, new TrainParams()).map(param => {
val sc = NNContext.initNNContext()

val createLabel = udf { row: Row =>
if (new Path(row.getString(0)).getName.contains("cat")) 1 else 2
}
val imgDF = NNImageReader.readImages(param.folder, sc, resizeH = 150, resizeW = 150)
.withColumn("label", createLabel(col("image")))
val Array(validationDF, trainingDF) = imgDF.randomSplit(Array(0.1, 0.9), seed = 42L)

val transformers = ImageChannelNormalize(0, 0, 0, 255, 255, 255)
val model = buildMode(Shape(3, 150, 150))

val optimMethod = new RMSprop[Float]()

model.compile(optimizer = optimMethod,
loss = BinaryCrossEntropy[Float](),
metrics = List(new Top1Accuracy[Float]()))
model.fit(trainingDF, batchSize = param.batchSize, nbEpoch = param.maxEpoch,
labelCols = Array("label"), transform = transformers, valX = validationDF)

sc.stop()
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -577,15 +577,16 @@ abstract class KerasNet[T](implicit val tag: ClassTag[T], implicit val ev: Tenso
x: DataFrame,
batchSize: Int,
nbEpoch: Int,
labelCol: String,
labelCols: Array[String],
transform: ImageProcessing,
valX: DataFrame)(implicit ev: TensorNumeric[T]): Unit = {
val trainData = df2ImageSet(x, labelCol, transform)
require(labelCols.length == 1, "current only support one label for dataframe of image")
val trainData = df2ImageSet(x, labelCols.head, transform)
val transformer2 = ImageMatToTensor[Float]() -> ImageSetToSample[Float]()
trainData.transform(transformer2)

val valData = if (valX != null) {
val valSet = df2ImageSet(valX, labelCol, transform)
val valSet = df2ImageSet(valX, labelCols.head, transform)
valSet.transform(transformer2)
valSet
} else null
Expand All @@ -597,9 +598,9 @@ abstract class KerasNet[T](implicit val tag: ClassTag[T], implicit val ev: Tenso
x: DataFrame,
batchSize: Int,
nbEpoch: Int,
labelCol: String,
labelCols: Array[String],
transform: ImageProcessing)(implicit ev: TensorNumeric[T]): Unit = {
this.fit(x, batchSize, nbEpoch, labelCol, transform, null)
this.fit(x, batchSize, nbEpoch, labelCols, transform, null)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package com.intel.analytics.bigdl.dllib.keras.models

import com.intel.analytics.bigdl.dllib.feature.dataset.{LocalDataSet, MiniBatch, Sample}
import com.intel.analytics.bigdl.dllib.optim.{Loss, SGD, Top1Accuracy, Top5Accuracy}
import com.intel.analytics.bigdl.dllib.optim._
import com.intel.analytics.bigdl.dllib.utils.python.api.PythonBigDL
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.dllib.tensor.{Storage, Tensor}
Expand All @@ -32,7 +32,7 @@ import com.intel.analytics.bigdl.dllib.keras.{Sequential, ZooSpecHelper}
import com.intel.analytics.bigdl.dllib.keras.layers._
import com.intel.analytics.bigdl.dllib.keras.models.Sequential
import com.intel.analytics.bigdl.dllib.keras.models.Model
import com.intel.analytics.bigdl.dllib.keras.objectives.ZooClassNLLCriterion
import com.intel.analytics.bigdl.dllib.keras.objectives.{BinaryCrossEntropy, ZooClassNLLCriterion}
import com.intel.analytics.bigdl.dllib.keras.python.PythonZooKeras
import com.intel.analytics.bigdl.dllib.nn.{CosineEmbeddingCriterion, MSECriterion, MarginRankingCriterion, ParallelCriterion}
import com.intel.analytics.bigdl.dllib.nnframes.{NNEstimatorSpec, NNImageReader}
Expand Down Expand Up @@ -449,7 +449,8 @@ class TrainingSpec extends ZooSpecHelper {
model.add(Reshape[Float](Array(169)))
model.add(Dense[Float](2, activation = "log_softmax"))
model.compile(optimizer = new SGD[Float](), loss = ZooClassNLLCriterion[Float]())
model.fit(imgDF, batchSize = 1, nbEpoch = 1, labelCol = "label", transform = transformers)
model.fit(imgDF, batchSize = 1, nbEpoch = 1, labelCols = Array("label"),
transform = transformers)
val predDf = model.predict(imgDF, predictionCol = "predict", transform = transformers)
predDf.show()
model.evaluate(imgDF, batchSize = 1, labelCol = "label", transform = transformers)
Expand Down

0 comments on commit fc4b6bd

Please sign in to comment.