Commit a9f3ecb: adding new code
datnt88 committed Feb 27, 2017
1 parent 37b9a0c commit a9f3ecb
Showing 34 changed files with 310 additions and 2,068 deletions.
Binary file added .DS_Store
Binary file not shown.
Binary file added pHash/.DS_Store
Binary file not shown.
7 changes: 7 additions & 0 deletions pHash/compute_pHash.py
@@ -0,0 +1,7 @@
from PIL import Image
import imagehash
import sys

# print the 64-bit perceptual hash (pHash) of the image passed as the first argument
phash = imagehash.phash(Image.open(sys.argv[1]))
print(phash)
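imagehash.phash yields a 64-bit hash, printed as a 16-character hex string; two images are near-duplicates when the Hamming distance between their hashes is small. A minimal Java sketch of that comparison (not part of the commit; the example hashes and the 10-bit threshold are illustrative assumptions):

public class PHashDistance {
    // Hamming distance between two 64-bit pHashes given as hex strings
    public static int distance(String h1, String h2) {
        long a = Long.parseUnsignedLong(h1, 16);
        long b = Long.parseUnsignedLong(h2, 16);
        return Long.bitCount(a ^ b); // number of differing bits
    }

    public static void main(String[] args) {
        // e.g. two outputs of compute_pHash.py; these differ in a single bit
        System.out.println(distance("80d1c939278dc6e1", "80d1c939278dc6e3") <= 10); // true => near-duplicates
    }
}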

2,000 changes: 0 additions & 2,000 deletions sample/150311152434_cyclone_pam-15_20150326_vol-4.json

This file was deleted.

22 changes: 12 additions & 10 deletions src/main/java/hbku/qcri/sc/aidr/filtering/ImageFilter.java
100755 → 100644
@@ -4,11 +4,13 @@
import java.io.IOException;

import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.transform.ImageTransform;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -36,14 +38,14 @@ public ImageFilter(String modelPath) {
}

//doing binary classification for one image
-public boolean doFilteringImage(String collection_id, String im_file) {
+public boolean doClassify(String collection_id, String im_file) {
File file = new File(im_file);
NativeImageLoader loader = new NativeImageLoader(height, width, channels);
try {
INDArray image = loader.asMatrix(file);
-//DataNormalization scaler = new NormalizerStandardize();
-//ImageTransform myTransform = new MyImageTransform(null, 121,121,122);
-DataNormalization scaler = new ImagePreProcessingScaler(0,1);
+
+//using the image scale normalization 0 - 255
+DataNormalization scaler = new ImagePreProcessingScaler(0,255);
scaler.transform(image);
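// note: this scaler range must match the one used at training time (the training code in this commit also uses 0-255)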

// Pass through to neural Net
@@ -53,8 +55,8 @@ public boolean doFilteringImage(String collection_id, String im_file) {
//System.out.println("A pass through network took: " + (endTime - startTime)/1000 + " seconds"); //measure performance
float p_l1 = output.getFloat(0);
float p_l2 = output.getFloat(1);
-System.out.println(p_l1);
-System.out.println(p_l2);
+//System.out.println(p_l1);
+//System.out.println(p_l2);
if(p_l1 > p_l2){ // the classified label is NEG
return false;
}
@@ -68,10 +70,10 @@ public boolean doFilteringImage(String collection_id, String im_file) {
}

public static void main(String args[]){
String modelPath = "./gold_models/gold-model-alex-dl4j-ep-7.zip";
String modelPath = "./gold_models/alex-dl4j-ep-7.zip";
ImageFilter filter = new ImageFilter(modelPath);
-System.out.println(filter.doFilteringImage("nepal_eq", "./im_data/POS/ecuador_eq_mild_im_89.jpg"));
-System.out.println(filter.doFilteringImage("nepal_eq", "./im_data/POS/ecuador_eq_mild_im_89.jpg"));
-System.out.println(filter.doFilteringImage("nepal_eq", "./im_data/POS/ecuador_eq_mild_im_89.jpg"));
+System.out.println(filter.doClassify("nepal_eq", "./test_img/581045846810169344.jpg"));
+System.out.println(filter.doClassify("nepal_eq", "./test_img/581046110665318400.png"));
+System.out.println(filter.doClassify("nepal_eq", "./test_img/581046449544130560.jpg"));
}
}
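For reference, a standalone sketch of the same inference path (not part of the commit): restore a checkpoint written by the training code, load one image with the identical 227x227x3 / 0-255 preprocessing, and compare the two class scores. File paths are illustrative.

import java.io.File;
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;

public class ClassifyOne {
    public static void main(String[] args) throws Exception {
        // restore a serialized epoch checkpoint produced by ImageFilteringModel
        MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(new File("saved_models/alex-dl4j-ep-7.zip"));
        // load and scale the image exactly as at training time: 227x227x3, 0-255 range
        INDArray image = new NativeImageLoader(227, 227, 3).asMatrix(new File("test_img/sample.jpg"));
        DataNormalization scaler = new ImagePreProcessingScaler(0, 255);
        scaler.transform(image);
        INDArray out = net.output(image);
        // index 0 = NEG, index 1 = POS (per doClassify above)
        System.out.println(out.getFloat(1) > out.getFloat(0) ? "POS" : "NEG");
    }
}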
110 changes: 63 additions & 47 deletions src/main/java/hbku/qcri/sc/aidr/filtering/ImageFilteringModel.java
@@ -1,5 +1,9 @@
package hbku.qcri.sc.aidr.filtering;

+import org.apache.commons.cli.BasicParser;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.CommandLineParser;
+import org.apache.commons.cli.Options;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
@@ -27,6 +31,7 @@
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
+import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
@@ -45,67 +50,87 @@ public class ImageFilteringModel {
protected static int height = 227;
protected static int width = 227;
protected static int channels = 3;
-protected static int numExamples = 5600;
+protected static int numExamples = 4000;
protected static int numLabels = 2;
protected static int batchSize = 64;

-protected static long seed = 42;
+protected static long seed = 113;
protected static Random rng = new Random(seed);
protected static int listenerFreq = 1;
protected static int iterations = 2;
-protected static int epochs = 10;
+protected static int epochs = 30;
protected static double splitTrainTest = 0.75;
-protected static int nCores = 6;
+protected static int nCores = 4;
protected static boolean save = true;


+protected static double l_rate = 1e-2;
protected static String modelType = "AlexNet"; // LeNet, AlexNet or Custom but you need to fill it out


public void run(String[] args) throws Exception {
//Nd4j.set

+//Parsing parameters
+CommandLineParser parser = new BasicParser();
+Options options = new Options();
+options.addOption("d", "im_data", true, "dataset folder");
+options.addOption("n", "num_examples", true, "number of examples");
+options.addOption("b", "batch_size", true, "batch size");
+options.addOption("i", "iterations", true, "iterations");
+options.addOption("l", "learning_rate", true, "learning rate");
+options.addOption("s", "saved_model", true, "folder for saved models");
+
+CommandLine commandLine = parser.parse(options, args);
+
+String dataDir = commandLine.getOptionValue('d', "z_data");
+String savedDir = commandLine.getOptionValue('s', "saved_models");
+
+numExamples = Integer.parseInt(commandLine.getOptionValue('n', "1000"));
+batchSize = Integer.parseInt(commandLine.getOptionValue('b', "20"));
+iterations = Integer.parseInt(commandLine.getOptionValue('i', "2"));
+l_rate = Double.parseDouble(commandLine.getOptionValue('l', "1e-3"));
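// example invocation (assumed, not from the commit):
//   java hbku.qcri.sc.aidr.filtering.ImageFilteringModel -d im_data -n 4000 -b 64 -i 2 -l 1e-3 -s saved_models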


log.info("----------------------------------------");
log.info("Pamateters:");
log.info("Dataset folder: " + dataDir);
log.info("Saved model folder: " + savedDir);
log.info("Num of Examples: " + numExamples);
log.info("Batch size: " + batchSize);
log.info("Iterations: " + iterations);
log.info("Learning rate: " + l_rate);
log.info("Image size: " + height);
log.info("Normalization: 0 - 1") ;
log.info("----------------------------------------");
log.info("Load data....");

/**
* Data Setup -> organize and limit data file paths:
* - mainPath = path to image files
* - fileSplit = define basic dataset split with limits on format
* - pathFilter = define additional file load filter to limit size and balance batch content
**/
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
-File mainPath = new File(System.getProperty("user.dir"), "im_data/");
+File mainPath = new File(System.getProperty("user.dir"), dataDir);
FileSplit fileSplit = new FileSplit(mainPath, NativeImageLoader.ALLOWED_FORMATS, rng);
-System.out.println(numExamples);
-BalancedPathFilter pathFilter = new BalancedPathFilter(rng, labelMaker, numExamples, numLabels, 5600);

/**
* Data Setup -> train test split
* - inputSplit = define train and test split
**/
-System.out.println(numExamples);
-InputSplit[] inputSplit = fileSplit.sample(pathFilter, numExamples * splitTrainTest, numExamples * (1 - splitTrainTest));
+//BalancedPathFilter pathFilter = new BalancedPathFilter(rng, labelMaker, numExamples, numLabels, numExamples/2);
+BalancedPathFilter pathFilter = new BalancedPathFilter(rng, labelMaker, numExamples/2);

+//- inputSplit = define train and test split
+InputSplit[] inputSplit = fileSplit.sample(pathFilter, 75, 25);
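// note: sample() treats 75/25 as relative weights, i.e. roughly a 75%/25% train/dev split of the filtered paths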

InputSplit trainData = inputSplit[0];
InputSplit testData = inputSplit[1];

System.out.println("Num of train: " + trainData.length());
System.out.println("Num of test: " + testData.length());
log.info("Num of train: " + trainData.length());
log.info("Num of test: " + testData.length());

/**
* Data Setup -> normalization
* - how to normalize images and generate large dataset to train on
**/
+//using the image scale normalization 0 - 255
+DataNormalization scaler = new ImagePreProcessingScaler(0,255);
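// note: ImagePreProcessingScaler(min, max) maps 0-255 pixel values into [min, max],
// so (0,255) effectively leaves pixels unscaled, while (0,1) would normalize to 0-1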

-ImageTransform cropTransform = new ScaleImageTransform(null, 14);
-ImageTransform resizeTransform = new ResizeImageTransform(10, 14);
-ImageTransform myTransform = new MyImageTransform(null, 121,121,122);

-//DataNormalization scaler = new ImagePreProcessingScaler(0, 1 );
-DataNormalization scaler = new NormalizerStandardize();

log.info("Build model....");

// Uncomment below to try AlexNet. Note change height and width to at least 100
//MultiLayerNetwork network = new AlexNet(height, width, channels, numLabels, seed, iterations).init();

log.info("Build model....");
//Using Alexnet
MultiLayerNetwork network = alexnetModel();
network.init();
@@ -143,9 +168,7 @@ public void run(String[] args) throws Exception {

String loc2save;
for( int i=1; i<epochs + 1; i++ ) {
network.fit(trainIter);
log.info("*** Completed epoch {} ***", i);
log.info("Evaluate model on dev set....");
Evaluation eval = new Evaluation(2);
@@ -159,7 +182,7 @@ public void run(String[] args) throws Exception {
log.info("--- Dev Rec: " + eval.recall());
log.info("--- Dev F1: " + eval.f1());
log.info("-------------------------------");
loc2save = "saved_models/alex-dl4j-final-ep-" + i + ".zip";
loc2save = savedDir + "/alex-dl4j-ep-" + i + ".zip";

//Updater: i.e., the state for Momentum, RMSProp, Adagrad etc. Save this if you want to train your network more in the future
ModelSerializer.writeModel(network, loc2save, true);
@@ -193,7 +216,7 @@ private DenseLayer fullyConnected(String name, int out, double bias, double dropOut) {


public MultiLayerNetwork alexnetModel() {
/**
* AlexNet model interpretation based on the original paper ImageNet Classification with Deep Convolutional Neural Networks
* and the imagenetExample code referenced.
* http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf
@@ -203,25 +226,24 @@ public MultiLayerNetwork alexnetModel() {
double dropOut = 0.5;

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(seed)
.weightInit(WeightInit.DISTRIBUTION)
.dist(new NormalDistribution(0.0, 0.01))
.activation(Activation.RELU)
-.updater(Updater.NESTEROVS)
+.updater(Updater.SGD)
.iterations(iterations)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) // normalize to prevent vanishing or exploding gradients
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
-.learningRate(1e-3)
+.learningRate(l_rate)
.biasLearningRate(1e-2*2)
.learningRateDecayPolicy(LearningRatePolicy.Step)
.lrPolicyDecayRate(0.1)
-.lrPolicySteps(10)
+.lrPolicySteps(100000)
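// with Step decay, the rate now drops only every 100k iterations, i.e. it stays effectively constant for short runs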
.regularization(true)
.l2(5 * 1e-4)
.momentum(0.9)
.miniBatch(false)
.list()
.layer(0, convInit("cnn1", channels, 96, new int[]{11, 11}, new int[]{4, 4}, new int[]{3, 3}, 0))
.layer(1, new LocalResponseNormalization.Builder().name("lrn1").build())
.layer(2, maxPool("maxpool1", new int[]{3,3}))
Expand All @@ -243,15 +265,9 @@ public MultiLayerNetwork alexnetModel() {
.pretrain(false)
.setInputType(InputType.convolutional(height, width, channels))
.build();
return new MultiLayerNetwork(conf);
}
//Get prediction of an input image
public static Boolean getPrediction(String im_file, String model_path){
//Load the pretrained model and delegate to the classifier (collection id is unused here)
ImageFilter filter = new ImageFilter(model_path);
return filter.doClassify("", im_file);
}

public static void main(String[] args) {
80 changes: 80 additions & 0 deletions src/main/java/hbku/qcri/sc/aidr/imagecrawler/DoStat.java
@@ -0,0 +1,80 @@
package hbku.qcri.sc.aidr.imagecrawler;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.cli.BasicParser;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import hbku.qcri.sc.aidr.imagecrawler.ImageCrawler;

public class DoStat {
protected static final Logger log = LoggerFactory.getLogger(DoStat.class);

public static void main(String args[]) throws ParseException, IOException {
CommandLineParser parser = new BasicParser();
Options options = new Options();
options.addOption("f", "folder", true, "Save image to the folder under the name of collection.");
options.addOption("j", "json_file", true, "Json file of the collection.");
options.addOption("m", "mode", true, "Run in parallel?");

CommandLine commandLine = null;

try{
commandLine = parser.parse(options, args);
}catch(ParseException ex){
log.error("Please provide command line options.");
log.error(ex.getMessage());
System.exit(0);
}
String json_file = commandLine.getOptionValue('j', "sample/150311152434_cyclone_pam-15_20150326_vol-4.json");

FileInputStream fstream = new FileInputStream(json_file);
BufferedReader br = new BufferedReader(new InputStreamReader(fstream));
List<String> urls = new ArrayList<String>();
List<Tweet> twts = new ArrayList<Tweet>();
Tweet twt = null;
String strLine;
String url = "";
int count = 0;
int twt_count = 0;

while ((strLine = br.readLine()) != null) {
// throw an exception for bad file here
twt_count +=1;
twt = null; // reset so a failed parse is not mistaken for the previous tweet
try{
twt = ImageCrawler.parseDataFeed(strLine); // Get the tweet
}
catch (Exception e){
log.error(json_file + " has a format problem...!");
}
if (twt != null) {
url = twt.imageURL;
if (url.trim() != "") {
count += 1;
if (!urls.contains(url)) {
urls.add(url);
twts.add(twt);
}
}
}
}
// json_file, num of tweets, num of urls, num of unique urls
// running the deduplication
// running the filtering

System.out.println(json_file +"," + twt_count + "," + count + "," + twts.size());
br.close();

}
}
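A sketch of how the two TODO steps above could chain the pieces added in this commit, to be appended after the stats loop (hypothetical: pHashOf() and localPathOf() do not exist in this commit, the 10-bit threshold is illustrative, and ImageFilter would need an import from hbku.qcri.sc.aidr.filtering):

// drop near-duplicate images by pHash, then classify the survivors
ImageFilter filter = new ImageFilter("./gold_models/alex-dl4j-ep-7.zip");
List<Long> seen = new ArrayList<Long>();
int kept = 0;
for (Tweet t : twts) {
    long h = pHashOf(t.imageURL);            // hypothetical helper, e.g. shelling out to compute_pHash.py
    boolean dup = false;
    for (long s : seen) {
        if (Long.bitCount(h ^ s) <= 10) { dup = true; break; } // small Hamming distance => near-duplicate
    }
    if (dup) continue;
    seen.add(h);
    if (filter.doClassify("stat", localPathOf(t.imageURL))) { // hypothetical download helper
        kept += 1;
    }
}
log.info("kept " + kept + " unique relevant images");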