Merge pull request tensorflow#4503 from martinwicke/branch_133795652
Branch 133795652
martinwicke authored Sep 21, 2016
2 parents 640353d + 6e0c1b8 commit 754048a
Showing 104 changed files with 3,503 additions and 1,300 deletions.
1 change: 1 addition & 0 deletions tensorflow/BUILD
@@ -94,6 +94,7 @@ filegroup(
"//tensorflow/c:all_files",
"//tensorflow/cc:all_files",
"//tensorflow/contrib:all_files",
"//tensorflow/contrib/android:all_files",
"//tensorflow/contrib/bayesflow:all_files",
"//tensorflow/contrib/copy_graph:all_files",
"//tensorflow/contrib/cudnn_rnn:all_files",
5 changes: 4 additions & 1 deletion tensorflow/c/c_api.cc
@@ -665,7 +665,10 @@ extern "C" {

struct TF_Graph {
  TF_Graph()
      : graph(OpRegistry::Global()), num_sessions(0), delete_requested(false) {}
      : graph(OpRegistry::Global()),
        refiner(graph.op_registry()),
        num_sessions(0),
        delete_requested(false) {}
  mutex mu;
  Graph graph GUARDED_BY(mu);

5 changes: 3 additions & 2 deletions tensorflow/cc/framework/scope.cc
@@ -32,8 +32,9 @@ Scope::Scope(Graph* graph, Status* status, Scope::NameMap* name_map,
      scope_used_(nullptr) {}

Scope Scope::NewRootScope() {
  return Scope(new Graph(OpRegistry::Global()), new Status, new Scope::NameMap,
               new ShapeRefiner);
  Graph* graph = new Graph(OpRegistry::Global());
  ShapeRefiner* refiner = new ShapeRefiner(graph->op_registry());
  return Scope(graph, new Status, new Scope::NameMap, refiner);
}

Scope::Scope(const Scope& other, Scope::Tags::ScopeName, const string& name,
59 changes: 59 additions & 0 deletions tensorflow/contrib/android/BUILD
@@ -0,0 +1,59 @@
# Description:
#   JNI-based Java inference interface for TensorFlow.

package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

load(
    "//tensorflow:tensorflow.bzl",
    "tf_copts",
    "if_android",
)

# TODO(andrewharp): Make this an android_library or java_library.
filegroup(
    name = "android_tensorflow_inference_java_srcs",
    srcs = glob(["java/**/*.java"]),
    visibility = ["//visibility:public"],
)

exports_files([
    "jni/version_script.lds",
])

filegroup(
    name = "android_tensorflow_inference_jni_srcs",
    srcs = glob([
        "jni/**/*.cc",
        "jni/**/*.h",
    ]),
    visibility = ["//visibility:public"],
)

cc_library(
    name = "android_tensorflow_inference_jni",
    srcs = if_android([":android_tensorflow_inference_jni_srcs"]),
    copts = tf_copts(),
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:android_tensorflow_lib_lite",
    ],
    alwayslink = 1,
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
            "bin/**",
            "gen/**",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
75 changes: 75 additions & 0 deletions tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
@@ -0,0 +1,75 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

package org.tensorflow.contrib.android;

import android.content.res.AssetManager;

import java.util.Random;

/**
 * JNI wrapper class for the Tensorflow native code.
 *
 * See tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java
 * for an example usage.
 */
public class TensorFlowInferenceInterface {
  /**
   * A unique identifier used to associate the Java TensorFlowInferenceInterface
   * with its associated native variables.
   * It is accessed via native reflection so any refactoring must also be accompanied
   * by a change to tensorflow_inference_jni.cc.
   */
  private final long id;

  public TensorFlowInferenceInterface() {
    id = new Random().nextLong();
  }

  /**
   * Creates a native TensorFlow session for the given model.
   *
   * @param assetManager The AssetManager to use to load the model file.
   * @param model The filepath to the GraphDef proto representing the model.
   * @return The native status returned by TensorFlow. 0 indicates success.
   */
  public native int initializeTensorFlow(AssetManager assetManager, String model);

  /**
   * Runs inference between the previously registered input nodes (via fillNode*)
   * and the requested output nodes. Output nodes can then be queried with the
   * readNode* methods.
   *
   * @param outputNames A list of output nodes which should be filled by the inference pass.
   * @return The native status returned by TensorFlow. 0 indicates success.
   */
  public native int runInference(String[] outputNames);

  /**
   * Cleans up the native variables associated with this Object. initializeTensorFlow() can then
   * be called again to initialize a new session.
   */
  public native void close();

  // Methods for creating a native Tensor and filling it with values.
  public native void fillNodeFloat(String inputName, int x, int y, int z, int d, float[] values);
  public native void fillNodeInt(String inputName, int x, int y, int z, int d, int[] values);
  public native void fillNodeDouble(String inputName, int x, int y, int z, int d, double[] values);

  public native void readNodeFloat(String outputName, float[] values);
  public native void readNodeInt(String outputName, int[] values);
  public native void readNodeDouble(String outputName, double[] values);
}
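For orientation, here is a minimal usage sketch of the interface declared above, in the spirit of the TensorFlowImageClassifier example it references. Only the method signatures come from the class itself; the model path, node names, and tensor shapes below are illustrative assumptions, not values defined by this change.

import android.content.res.AssetManager;

import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

/** Hypothetical caller: model path, node names, and shapes below are placeholders. */
public class InferenceExampleSketch {
  public static float[] runOnce(AssetManager assetManager, float[] input) {
    TensorFlowInferenceInterface inference = new TensorFlowInferenceInterface();

    // Load the GraphDef from the APK assets; 0 indicates success per the javadoc above.
    int status =
        inference.initializeTensorFlow(assetManager, "file:///android_asset/model.pb");
    if (status != 0) {
      throw new IllegalStateException("initializeTensorFlow failed with status " + status);
    }

    // Register the input tensor (assumed shape [1, 224, 224, 3]), then run the graph.
    inference.fillNodeFloat("input", 1, 224, 224, 3, input);
    inference.runInference(new String[] {"output"});

    // Copy the requested output node back into a Java array (assumed size).
    float[] output = new float[1000];
    inference.readNodeFloat("output", output);

    // Release the native session; initializeTensorFlow() may be called again afterwards.
    inference.close();
    return output;
  }
}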
tensorflow/contrib/android/jni/jni_utils.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/examples/android/jni/jni_utils.h"
#include "tensorflow/contrib/android/jni/jni_utils.h"

#include <android/asset_manager.h>
#include <android/asset_manager_jni.h>
@@ -29,8 +29,8 @@ limitations under the License.
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/io/zero_copy_stream_impl_lite.h"
#include "google/protobuf/message_lite.h"
#include "tensorflow/contrib/android/jni/limiting_file_input_stream.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/examples/android/jni/limiting_file_input_stream.h"

static const char* const ASSET_PREFIX = "file:///android_asset/";

@@ -75,9 +75,9 @@ bool IsAsset(const char* const filename) {
return strstr(filename, ASSET_PREFIX) == filename;
}

void ReadFileToProto(AAssetManager* const asset_manager,
const char* const filename,
google::protobuf::MessageLite* message) {
void ReadFileToProtoOrDie(AAssetManager* const asset_manager,
const char* const filename,
google::protobuf::MessageLite* message) {
if (!IsAsset(filename)) {
VLOG(0) << "Opening file: " << filename;
CHECK(PortableReadFileToProto(filename, message));
@@ -125,57 +125,15 @@ void ReadFileToProto(AAssetManager* const asset_manager,
AAsset_close(asset);
}

void ReadFileToString(AAssetManager* const asset_manager,
const char* const filename, std::string* str) {
if (!IsAsset(filename)) {
VLOG(0) << "Opening file: " << filename;
std::ifstream t(filename);
std::string tmp((std::istreambuf_iterator<char>(t)),
std::istreambuf_iterator<char>());
tmp.swap(*str);
t.close();
return;
}

CHECK_NOTNULL(asset_manager);
const char* const asset_filename = filename + strlen(ASSET_PREFIX);
AAsset* asset =
AAssetManager_open(asset_manager, asset_filename, AASSET_MODE_STREAMING);
CHECK_NOTNULL(asset);
VLOG(0) << "Opening asset " << asset_filename << " from disk with copy.";
const off_t data_size = AAsset_getLength(asset);
const char* memory = reinterpret_cast<const char*>(AAsset_getBuffer(asset));

std::string tmp(memory, memory + data_size);
tmp.swap(*str);
AAsset_close(asset);
}

void ReadFileToVector(AAssetManager* const asset_manager,
const char* const filename,
std::vector<std::string>* str_vector) {
std::string labels_string;
ReadFileToString(asset_manager, filename, &labels_string);
std::istringstream ifs(labels_string);
str_vector->clear();
std::string label;
while (std::getline(ifs, label)) {
str_vector->push_back(label);
}
VLOG(0) << "Read " << str_vector->size() << " values from " << filename;
std::string GetString(JNIEnv* env, jstring java_string) {
const char* raw_string = env->GetStringUTFChars(java_string, 0);
std::string return_str(raw_string);
env->ReleaseStringUTFChars(java_string, raw_string);
return return_str;
}

void WriteProtoToFile(const char* const filename,
const google::protobuf::MessageLite& message) {
std::fstream outfile;
outfile.open(filename, std::fstream::binary | std::fstream::out);
std::string serialized;
message.SerializeToString(&serialized);
outfile.write(serialized.c_str(), serialized.size());
outfile.close();
if (outfile.fail()) {
LOG(WARNING) << "Failed to write proto to " << filename;
return;
}
VLOG(0) << "Wrote proto to " << filename;
tensorflow::int64 CurrentWallTimeUs() {
struct timeval tv;
gettimeofday(&tv, NULL);
return tv.tv_sec * 1000000 + tv.tv_usec;
}
tensorflow/contrib/android/jni/jni_utils.h
@@ -20,29 +20,22 @@ limitations under the License.
#include <string>
#include <vector>

#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"

namespace google {
namespace protobuf {
class MessageLite;
} // google
} // protobuf

class AAssetManager;

bool PortableReadFileToProto(const std::string& file_name,
::google::protobuf::MessageLite* proto);

void ReadFileToProto(AAssetManager* const asset_manager,
const char* const filename, google::protobuf::MessageLite* message);
::google::protobuf::MessageLite* proto)
TF_MUST_USE_RESULT;

void ReadFileToString(AAssetManager* const asset_manager,
const char* const filename, std::string* str);
// Deserializes the contents of a file into memory.
void ReadFileToProtoOrDie(AAssetManager* const asset_manager,
const char* const filename,
google::protobuf::MessageLite* message);

void ReadFileToVector(AAssetManager* const asset_manager,
const char* const filename, std::vector<std::string>* str_vector);
std::string GetString(JNIEnv* env, jstring java_string);

void WriteProtoToFile(const char* const filename,
const google::protobuf::MessageLite& message);
tensorflow::int64 CurrentWallTimeUs();

#endif // ORG_TENSORFLOW_JNI_JNI_UTILS_H_