diff --git a/.gitignore b/.gitignore
index 26ae9a84..7f594401 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,6 +11,7 @@ README.org
 .cache/
 .history/
 .lib/
+.coverage
 dist/*
 target/
 lib_managed/
diff --git a/README.md b/README.md
index 5cd68ba7..fa4a16d9 100644
--- a/README.md
+++ b/README.md
@@ -131,16 +131,16 @@ Spark DataFrames are a natural construct for applying deep learning models to a
 
     ```python
     from sparkdl import readImages, TFImageTransformer
+    import sparkdl.graph.utils as tfx
     from sparkdl.transformers import utils
     import tensorflow as tf
 
-    g = tf.Graph()
-    with g.as_default():
+    graph = tf.Graph()
+    with tf.Session(graph=graph) as sess:
         image_arr = utils.imageInputPlaceholder()
         resized_images = tf.image.resize_images(image_arr, (299, 299))
-        # the following step is not necessary for this graph, but can be for graphs with variables, etc
-        frozen_graph = utils.stripAndFreezeGraph(g.as_graph_def(add_shapes=True), tf.Session(graph=g),
-                                                 [resized_images])
+        frozen_graph = tfx.strip_and_freeze_until([resized_images], graph, sess,
+                                                  return_graph=True)
 
     transformer = TFImageTransformer(inputCol="image", outputCol="predictions", graph=frozen_graph,
                                      inputTensor=image_arr, outputTensor=resized_images,
@@ -241,7 +241,7 @@ registerKerasImageUDF("my_keras_inception_udf", InceptionV3(weights="imagenet"),
 
 ```
 
+### Estimator
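+
+`TFTextFileEstimator` trains a TensorFlow model, supplied as a `map_fun`, on a DataFrame column
+produced by `TFTextTransformer`. The sketch below is adapted from
+`python/tests/transformers/tf_text_test.py`; it assumes a Kafka broker reachable at `127.0.0.1`, a
+DataFrame `documentDF` with a string column `text` and an integer label column `preds`, and a
+`map_fun` like the one documented in `python/sparkdl/estimators/tf_text_file_estimator.py`.
+
+```python
+from sparkdl.estimators.tf_text_file_estimator import TFTextFileEstimator
+from sparkdl.transformers.tf_text import TFTextTransformer
+
+transformer = TFTextTransformer(inputCol="text", outputCol="sentence_matrix",
+                                embeddingSize=100, sequenceLength=64)
+df = transformer.transform(documentDF)
+
+estimator = TFTextFileEstimator(inputCol="sentence_matrix", outputCol="sentence_matrix",
+                                labelCol="preds",
+                                kafkaParam={"bootstrap_servers": ["127.0.0.1"], "topic": "test",
+                                            "group_id": "sdl_1"},
+                                fitParam=[{"epochs": 5, "batch_size": 64}],
+                                mapFnParam=map_fun)
+estimator.fit(df).collect()
+```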
 
 ## Releases:
 * 0.1.0 initial release
-
diff --git a/python/requirements.txt b/python/requirements.txt
index a98a4d17..39981df5 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -9,3 +9,4 @@ pygments>=2.2.0
 tensorflow==1.3.0
 pandas>=0.19.1
 six>=1.10.0
+kafka-python>=1.3.5
diff --git a/python/sparkdl/estimators/tf_text_file_estimator.py b/python/sparkdl/estimators/tf_text_file_estimator.py
new file mode 100644
index 00000000..3f35e0be
--- /dev/null
+++ b/python/sparkdl/estimators/tf_text_file_estimator.py
@@ -0,0 +1,317 @@
+#
+# Copyright 2017 Databricks, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pylint: disable=protected-access
+from __future__ import absolute_import, division, print_function
+
+import logging
+import threading
+import time
+import os
+import sys
+
+from kafka import KafkaConsumer
+from kafka import KafkaProducer
+from pyspark.ml import Estimator
+
+from sparkdl.param import (
+    keyword_only, HasLabelCol, HasInputCol, HasOutputCol)
+from sparkdl.param.shared_params import KafkaParam, FitParam, MapFnParam
+import sparkdl.utils.jvmapi as JVMAPI
+
+if sys.version_info[:2] <= (2, 7):
+    import cPickle as pickle
+else:
+    import pickle
+
+__all__ = ['TFTextFileEstimator']
+
+logger = logging.getLogger('sparkdl')
+
+
+class TFTextFileEstimator(Estimator, HasInputCol, HasOutputCol, HasLabelCol, KafkaParam, FitParam, MapFnParam):
+    """
+    Build an Estimator from TensorFlow, or from Keras with a TensorFlow backend.
+
+    First, assume we have data in a DataFrame like the following.
+
+    .. code-block:: python
+            documentDF = self.session.createDataFrame([
+                                                        ("Hi I heard about Spark", 1),
+                                                        ("I wish Java could use case classes", 0),
+                                                        ("Logistic regression models are neat", 2)
+                                                        ], ["text", "preds"])
+
+            transformer = TFTextTransformer(
+                                            inputCol=input_col,
+                                            outputCol=output_col)
+
+            df = transformer.transform(documentDF)
+
+     TFTextTransformer transforms the text column into `output_col`, which is a 2-D array.
+
+     Then we create a TensorFlow training function.
+
+     .. code-block:: python
+         def map_fun(args={}, ctx=None, _read_data=None):
+            import tensorflow as tf
+            EMBEDDING_SIZE = args["embedding_size"]
+            feature = args['feature']
+            label = args['label']
+            params = args['params']['fitParam']
+            SEQUENCE_LENGTH = 64
+
+            def feed_dict(batch):
+                # Collect the sentence matrix from each record in the batch.
+                features = []
+                for i in batch:
+                    features.append(i['sentence_matrix'])
+
+                # print("{} {}".format(feature, features))
+                return features
+
+            encoder_variables_dict = {
+                "encoder_w1": tf.Variable(
+                    tf.random_normal([SEQUENCE_LENGTH * EMBEDDING_SIZE, 256]), name="encoder_w1"),
+                "encoder_b1": tf.Variable(tf.random_normal([256]), name="encoder_b1"),
+                "encoder_w2": tf.Variable(tf.random_normal([256, 128]), name="encoder_w2"),
+                "encoder_b2": tf.Variable(tf.random_normal([128]), name="encoder_b2")
+            }
+
+     _read_data is a data generator. args provides the hyper parameters configured in this estimator.
+
+     Here is how to use _read_data:
+
+     .. code-block:: python
+        for data in _read_data(max_records=params.batch_size):
+            batch_data = feed_dict(data)
+            sess.run(train_step, feed_dict={input_x: batch_data})
+
+     Finally, we can create a TFTextFileEstimator to train our model:
+
+     .. code-block:: python
+            estimator = TFTextFileEstimator(inputCol="sentence_matrix",
+                                            outputCol="sentence_matrix", labelCol="preds",
+                                            kafkaParam={"bootstrap_servers": ["127.0.0.1"], "topic": "test",
+                                                    "group_id": "sdl_1"},
+                                            fitParam=[{"epochs": 5, "batch_size": 64}, {"epochs": 5, "batch_size": 1}],
+                                            mapFnParam=map_fun)
+            estimator.fit(df)
+
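+     Note that ``fit`` runs one training job per dict in ``fitParam`` and returns an RDD of
+     (paramMap, result) pairs, so call ``collect()`` to actually trigger training on the workers:
+
+     .. code-block:: python
+            estimator.fit(df).collect()
+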
+    """
+
+    @keyword_only
+    def __init__(self, inputCol=None, outputCol=None, labelCol=None, kafkaParam=None, fitParam=None, mapFnParam=None):
+        super(TFTextFileEstimator, self).__init__()
+        kwargs = self._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    def setParams(self, inputCol=None, outputCol=None, labelCol=None, kafkaParam=None, fitParam=None, mapFnParam=None):
+        kwargs = self._input_kwargs
+        return self._set(**kwargs)
+
+    def fit(self, dataset, params=None):
+        self._validateParams()
+        if params is None:
+            paramMaps = self.getFitParam()
+        elif isinstance(params, (list, tuple)):
+            if len(params) == 0:
+                paramMaps = [dict()]
+            else:
+                self._validateFitParams(params)
+                paramMaps = params
+        elif isinstance(params, dict):
+            paramMaps = [params]
+        else:
+            raise ValueError("Params must be either a param map or a list/tuple of param maps, "
+                             "but got %s." % type(params))
+        return self._fitInParallel(dataset, paramMaps)
+
+    def _validateParams(self):
+        """
+        Check Param values so we can throw errors on the driver, rather than workers.
+        :return: True if parameters are valid
+        """
+        if not self.isDefined(self.inputCol):
+            raise ValueError("Input column must be defined")
+        if not self.isDefined(self.outputCol):
+            raise ValueError("Output column must be defined")
+        return True
+
+    def _fitInParallel(self, dataset, paramMaps):
+
+        inputCol = self.getInputCol()
+        labelCol = self.getLabelCol()
+
+        from time import gmtime, strftime
+        kafkaParams = self.getKafkaParam()
+        topic = kafkaParams["topic"] + "_" + strftime("%Y-%m-%d-%H-%M-%S", gmtime())
+        group_id = kafkaParams["group_id"]
+        bootstrap_servers = kafkaParams["bootstrap_servers"]
+        kafka_test_mode = kafkaParams["test_mode"] if "test_mode" in kafkaParams else False
+        mock_kafka_file = kafkaParams["mock_kafka_file"] if kafka_test_mode else None
+
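+        # Training data flows through Kafka: _write_data pushes every partition of the dataset to a
+        # uniquely named topic (on a background thread outside of test mode), and each fit task later
+        # consumes that topic through the _read_data generator handed to mapFnParam. In test mode,
+        # KafkaMockServer stands in for a real Kafka broker.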
+        def _write_data():
+            def _write_partition(index, d_iter):
+                producer = KafkaMockServer(index, mock_kafka_file) if kafka_test_mode else KafkaProducer(
+                    bootstrap_servers=bootstrap_servers)
+                try:
+                    for d in d_iter:
+                        producer.send(topic, pickle.dumps(d))
+                    producer.send(topic, pickle.dumps("_stop_"))
+                    producer.flush()
+                finally:
+                    producer.close()
+                return []
+
+            dataset.rdd.mapPartitionsWithIndex(_write_partition).count()
+
+        if kafka_test_mode:
+            _write_data()
+        else:
+            t = threading.Thread(target=_write_data)
+            t.start()
+
+        stop_flag_num = dataset.rdd.getNumPartitions()
+        temp_item = dataset.take(1)[0]
+        vocab_s = temp_item["vocab_size"]
+        embedding_size = temp_item["embedding_size"]
+
+        sc = JVMAPI._curr_sc()
+
+        paramMapsRDD = sc.parallelize(paramMaps, numSlices=len(paramMaps))
+
+        # Obtain params for this estimator instance
+        baseParamMap = self.extractParamMap()
+        baseParamDict = dict([(param.name, val) for param, val in baseParamMap.items()])
+        baseParamDictBc = sc.broadcast(baseParamDict)
+
+        def _local_fit(override_param_map):
+            # Update params
+            params = baseParamDictBc.value
+            params["fitParam"] = override_param_map
+
+            def _read_data(max_records=64):
+                consumer = KafkaMockServer(0, mock_kafka_file) if kafka_test_mode else KafkaConsumer(topic,
+                                                                                                     group_id=group_id,
+                                                                                                     bootstrap_servers=bootstrap_servers,
+                                                                                                     auto_offset_reset="earliest",
+                                                                                                     enable_auto_commit=False
+                                                                                                     )
+                try:
+                    stop_count = 0
+                    fail_msg_count = 0
+                    while True:
+                        if kafka_test_mode:
+                            time.sleep(1)
+                        messages = consumer.poll(timeout_ms=1000, max_records=max_records)
+                        group_msgs = []
+                        for tp, records in messages.items():
+                            for record in records:
+                                try:
+                                    msg_value = pickle.loads(record.value)
+                                    if msg_value == "_stop_":
+                                        stop_count += 1
+                                    else:
+                                        group_msgs.append(msg_value)
+                                except Exception:
+                                    fail_msg_count += 1
+                        if len(group_msgs) > 0:
+                            yield group_msgs
+
+                        if kafka_test_mode:
+                            print(
+                                "stop_count = {} "
+                                "group_msgs = {} "
+                                "stop_flag_num = {} "
+                                "fail_msg_count = {}".format(stop_count,
+                                                             len(group_msgs),
+                                                             stop_flag_num,
+                                                             fail_msg_count))
+
+                        if stop_count >= stop_flag_num and len(group_msgs) == 0:
+                            break
+                finally:
+                    consumer.close()
+
+            self.getMapFnParam()(args={"feature": inputCol,
+                                       "label": labelCol,
+                                       "vocab_size": vocab_s,
+                                       "embedding_size": embedding_size,
+                                       "params": params}, ctx=None, _read_data=_read_data,
+                                 )
+
+        return paramMapsRDD.map(lambda paramMap: (paramMap, _local_fit(paramMap)))
+
+    def _fit(self, dataset):  # pylint: disable=unused-argument
+        err_msgs = ["This function should not have been called",
+                    "Please contact library maintainers to file a bug"]
+        raise NotImplementedError('\n'.join(err_msgs))
+
+
+class KafkaMockServer(object):
+    """
+      Restrictions of KafkaMockServer:
+       * Make sure all data has been written before consuming.
+       * poll() ignores max_records and returns all data currently in the queue.
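+
+       A rough usage sketch (mirroring python/tests/transformers/tf_text_test.py), where producer and
+       consumer share the same temporary directory:
+
+       .. code-block:: python
+           producer = KafkaMockServer(0, mock_kafka_file)
+           producer.send("ignored_topic", pickle.dumps("hello"))
+           producer.flush()
+
+           consumer = KafkaMockServer(0, mock_kafka_file)
+           messages = consumer.poll(timeout_ms=1000, max_records=64)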
+    """
+
+    _kafka_mock_server_tmp_file_ = None
+    sent = False
+
+    def __init__(self, index=0, tmp_file=None):
+        super(KafkaMockServer, self).__init__()
+        self.index = index
+        self.queue = []
+        self._kafka_mock_server_tmp_file_ = tmp_file
+        if not os.path.exists(self._kafka_mock_server_tmp_file_):
+            os.mkdir(self._kafka_mock_server_tmp_file_)
+
+    def send(self, topic, msg):
+        self.queue.append(pickle.loads(msg))
+
+    def flush(self):
+        with open(self._kafka_mock_server_tmp_file_ + "/" + str(self.index), "wb") as f:
+            pickle.dump(self.queue, f)
+        self.queue = []
+
+    def close(self):
+        pass
+
+    def poll(self, timeout_ms, max_records):
+        if self.sent:
+            return {}
+
+        records = []
+        for file in os.listdir(self._kafka_mock_server_tmp_file_):
+            with open(self._kafka_mock_server_tmp_file_ + "/" + file, "rb") as f:
+                tmp = pickle.load(f)
+                records += tmp
+        result = {}
+        counter = 0
+        for i in records:
+            obj = MockRecord()
+            obj.value = pickle.dumps(i)
+            counter += 1
+            result[str(counter) + "_"] = [obj]
+        self.sent = True
+        return result
+
+
+class MockRecord(list):
+    pass
diff --git a/python/sparkdl/param/shared_params.py b/python/sparkdl/param/shared_params.py
index e169e891..7305fc8b 100644
--- a/python/sparkdl/param/shared_params.py
+++ b/python/sparkdl/param/shared_params.py
@@ -27,6 +27,7 @@
 
 import sparkdl.utils.keras_model as kmutil
 
+
 # From pyspark
 
 def keyword_only(func):
@@ -36,15 +37,75 @@ def keyword_only(func):
 
     .. note:: Should only be used to wrap a method where first arg is `self`
     """
+
     @wraps(func)
     def wrapper(self, *args, **kwargs):
         if len(args) > 0:
             raise TypeError("Method %s forces keyword arguments." % func.__name__)
         self._input_kwargs = kwargs
         return func(self, **kwargs)
+
     return wrapper
 
 
+class KafkaParam(Params):
+    kafkaParam = Param(Params._dummy(), "kafkaParam",
+                       "Kafka configuration, e.g. bootstrap_servers, topic and group_id",
+                       typeConverter=TypeConverters.identity)
+
+    def __init__(self):
+        super(KafkaParam, self).__init__()
+
+    def setKafkaParam(self, value):
+        """
+        Sets the value of :py:attr:`kafkaParam`.
+        """
+        return self._set(kafkaParam=value)
+
+    def getKafkaParam(self):
+        """
+        Gets the value of kafkaParam or its default value.
+        """
+        return self.getOrDefault(self.kafkaParam)
+
+
+class FitParam(Params):
+    fitParam = Param(Params._dummy(), "fitParam", "hyper parameters for training, one dict per training run",
+                     typeConverter=TypeConverters.identity)
+
+    def __init__(self):
+        super(FitParam, self).__init__()
+
+    def setFitParam(self, value):
+        """
+        Sets the value of :py:attr:`fitParam`.
+        """
+        return self._set(fitParam=value)
+
+    def getFitParam(self):
+        """
+        Gets the value of fitParam or its default value.
+        """
+        return self.getOrDefault(self.fitParam)
+
+
+class MapFnParam(Params):
+    mapFnParam = Param(Params._dummy(), "mapFnParam", "TensorFlow training function",
+                       typeConverter=TypeConverters.identity)
+
+    def __init__(self):
+        super(MapFnParam, self).__init__()
+
+    def setMapFnParam(self, value):
+        """
+        Sets the value of :py:attr:`mapFnParam`.
+        """
+        return self._set(mapFnParam=value)
+
+    def getMapFnParam(self):
+        """
+        Gets the value of mapFnParam or its default value.
+        """
+        return self.getOrDefault(self.mapFnParam)
+
+
 class HasInputCol(Params):
     """
     Mixin for param inputCol: input column name.
@@ -68,6 +129,42 @@ def getInputCol(self):
         return self.getOrDefault(self.inputCol)
 
 
+class HasEmbeddingSize(Params):
+    """
+    Mixin for param embeddingSize
+    """
+
+    embeddingSize = Param(Params._dummy(), "embeddingSize", "word embedding size",
+                          typeConverter=TypeConverters.toInt)
+
+    def __init__(self):
+        super(HasEmbeddingSize, self).__init__()
+
+    def setEmbeddingSize(self, value):
+        return self._set(embeddingSize=value)
+
+    def getEmbeddingSize(self):
+        return self.getOrDefault(self.embeddingSize)
+
+
+class HasSequenceLength(Params):
+    """
+    Mixin for param sequenceLength
+    """
+
+    sequenceLength = Param(Params._dummy(), "sequenceLength", "sequence length",
+                           typeConverter=TypeConverters.toInt)
+
+    def __init__(self):
+        super(HasSequenceLength, self).__init__()
+
+    def setSequenceLength(self, value):
+        return self._set(sequenceLength=value)
+
+    def getSequenceLength(self):
+        return self.getOrDefault(self.sequenceLength)
+
+
 class HasOutputCol(Params):
     """
     Mixin for param outputCol: output column name.
@@ -92,12 +189,12 @@ def getOutputCol(self):
         """
         return self.getOrDefault(self.outputCol)
 
+
 ############################################
 # New in sparkdl
 ############################################
 
 class SparkDLTypeConverters(object):
-
     @staticmethod
     def toStringOrTFTensor(value):
         if isinstance(value, tf.Tensor):
diff --git a/python/sparkdl/transformers/tf_text.py b/python/sparkdl/transformers/tf_text.py
new file mode 100644
index 00000000..5f6b95a7
--- /dev/null
+++ b/python/sparkdl/transformers/tf_text.py
@@ -0,0 +1,104 @@
+# Copyright 2017 Databricks, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+from pyspark.ml import Transformer
+from pyspark.ml.feature import Word2Vec
+from pyspark.sql.functions import udf
+from pyspark.sql import functions as f
+from pyspark.sql.types import *
+from pyspark.sql.functions import lit
+from sparkdl.param.shared_params import HasEmbeddingSize, HasSequenceLength
+from sparkdl.param import (
+    keyword_only, HasInputCol, HasOutputCol)
+import re
+
+import sparkdl.utils.jvmapi as JVMAPI
+
+
+class TFTextTransformer(Transformer, HasInputCol, HasOutputCol, HasEmbeddingSize, HasSequenceLength):
+    """
+    Convert a sentence/document column into a 2-D array, e.g. [[word embedding], [....]], in a DataFrame
+    that can be consumed directly by TensorFlow or by Keras with a TensorFlow backend.
+
+    Processing steps:
+
+    * Use Word2Vec to compute a Map(word -> vector) from the input column, then broadcast the map.
+    * Split the input text column on whitespace, replace each word with its vector, and pad the result
+      to the configured sequence length.
+    * Create a new DataFrame with the 2-D array column plus vocab_size and embedding_size columns.
+    * Return the new DataFrame.
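+
+    A minimal usage sketch (mirroring python/tests/transformers/tf_text_test.py), assuming a DataFrame
+    with a string column named "text":
+
+    .. code-block:: python
+        transformer = TFTextTransformer(inputCol="text", outputCol="sentence_matrix",
+                                        embeddingSize=100, sequenceLength=64)
+        df = transformer.transform(documentDF)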
+    """
+    VOCAB_SIZE = 'vocab_size'
+    EMBEDDING_SIZE = 'embedding_size'
+
+    @keyword_only
+    def __init__(self, inputCol=None, outputCol=None, embeddingSize=100, sequenceLength=64):
+        super(TFTextTransformer, self).__init__()
+        kwargs = self._input_kwargs
+        self.setParams(**kwargs)
+
+    @keyword_only
+    def setParams(self, inputCol=None, outputCol=None, embeddingSize=100, sequenceLength=64):
+        kwargs = self._input_kwargs
+        return self._set(**kwargs)
+
+    def _transform(self, dataset):
+        sc = JVMAPI._curr_sc()
+
+        word2vec = Word2Vec(vectorSize=self.getEmbeddingSize(), minCount=1, inputCol=self.getInputCol(),
+                            outputCol="word_embedding")
+
+        vectorsDf = word2vec.fit(
+            dataset.select(f.split(self.getInputCol(), "\\s+").alias(self.getInputCol()))).getVectors()
+
+        """
+          It's strange here that after calling getVectors the df._sc._jsc will lose and this is
+          only happens when you run it with ./python/run-tests.sh script.
+          We add this code to make it pass the test. However it seems this will hit
+          "org.apache.spark.SparkException: EOF reached before Python server acknowledged" error.
+        """
+        if vectorsDf._sc._jsc is None:
+            vectorsDf._sc._jsc = sc._jsc
+
+        word_embedding = dict(vectorsDf.rdd.map(
+            lambda p: (p.word, p.vector.values.tolist())).collect())
+
+        word_embedding["unk"] = np.zeros(self.getEmbeddingSize()).tolist()
+        local_word_embedding = sc.broadcast(word_embedding)
+
+        def convert_word_to_index(s):
+            def _pad_sequences(sequences, maxlen=None):
+                new_sequences = []
+
+                if len(sequences) <= maxlen:
+                    for i in range(maxlen - len(sequences)):
+                        new_sequences.append(np.zeros(self.getEmbeddingSize()).tolist())
+                    return sequences + new_sequences
+                else:
+                    return sequences[0:maxlen]
+
+            # Look up each whitespace-separated token; tokens without an embedding are dropped.
+            new_q = [local_word_embedding.value[word] for word in re.split(r"\s+", s) if
+                     word in local_word_embedding.value]
+            result = _pad_sequences(new_q, maxlen=self.getSequenceLength())
+            return result
+
+        cwti_udf = udf(convert_word_to_index, ArrayType(ArrayType(FloatType())))
+        doc_matrix = (dataset.withColumn(self.getOutputCol(), cwti_udf(self.getInputCol()).alias(self.getOutputCol()))
+                      .withColumn(self.VOCAB_SIZE, lit(len(word_embedding)))
+                      .withColumn(self.EMBEDDING_SIZE, lit(self.getEmbeddingSize()))
+                      )
+
+        return doc_matrix
diff --git a/python/sparkdl/transformers/utils.py b/python/sparkdl/transformers/utils.py
index b244365b..9964f3df 100644
--- a/python/sparkdl/transformers/utils.py
+++ b/python/sparkdl/transformers/utils.py
@@ -18,6 +18,8 @@
 # image stuff
 
 IMAGE_INPUT_PLACEHOLDER_NAME = "sparkdl_image_input"
+TEXT_INPUT_PLACEHOLDER_NAME = "sparkdl_text_input"
+
 
 def imageInputPlaceholder(nChannels=None):
     return tf.placeholder(tf.float32, [None, None, None, nChannels],
diff --git a/python/tests/resources/text/sample.txt b/python/tests/resources/text/sample.txt
new file mode 100644
index 00000000..8c5e8d99
--- /dev/null
+++ b/python/tests/resources/text/sample.txt
@@ -0,0 +1,4 @@
+接下 来 介绍 一种 非常 重要 的 神经网络 卷积神经网络
+这种 神经 网络 在 计算机 视觉 领域 取得了 重大 的 成功,而且 在 自然语言 处理 等 其它 领域 也有 很好 应用
+深度学习 受到 大家 关注 很大 一个 原因 就是 Alex 实现 AlexNet( 一种 深度卷积神经网络 )在 LSVRC-2010 ImageNet
+此后 卷积神经网络 及其 变种 被广泛 应用于 各种图像 相关 任务
\ No newline at end of file
diff --git a/python/tests/transformers/tf_text_test.py b/python/tests/transformers/tf_text_test.py
new file mode 100644
index 00000000..99c505be
--- /dev/null
+++ b/python/tests/transformers/tf_text_test.py
@@ -0,0 +1,223 @@
+# Copyright 2017 Databricks, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import shutil
+import threading
+import sys
+from sparkdl.estimators.tf_text_file_estimator import TFTextFileEstimator, KafkaMockServer
+from sparkdl.transformers.tf_text import TFTextTransformer
+from ..tests import SparkDLTestCase
+
+if sys.version_info[:2] <= (2, 7):
+    import cPickle as pickle
+else:
+    import pickle
+
+
+def map_fun(args={}, ctx=None, _read_data=None):
+    import tensorflow as tf
+    EMBEDDING_SIZE = args["embedding_size"]
+    params = args['params']['fitParam']
+    SEQUENCE_LENGTH = 64
+
+    def feed_dict(batch):
+        # Collect the sentence matrix from each record in the batch.
+        features = []
+        for i in batch:
+            features.append(i['sentence_matrix'])
+        return features
+
+    encoder_variables_dict = {
+        "encoder_w1": tf.Variable(
+            tf.random_normal([SEQUENCE_LENGTH * EMBEDDING_SIZE, 256]), name="encoder_w1"),
+        "encoder_b1": tf.Variable(tf.random_normal([256]), name="encoder_b1"),
+        "encoder_w2": tf.Variable(tf.random_normal([256, 128]), name="encoder_w2"),
+        "encoder_b2": tf.Variable(tf.random_normal([128]), name="encoder_b2")
+    }
+
+    def encoder(x, name="encoder"):
+        with tf.name_scope(name):
+            encoder_w1 = encoder_variables_dict["encoder_w1"]
+            encoder_b1 = encoder_variables_dict["encoder_b1"]
+
+            layer_1 = tf.nn.sigmoid(tf.matmul(x, encoder_w1) + encoder_b1)
+
+            encoder_w2 = encoder_variables_dict["encoder_w2"]
+            encoder_b2 = encoder_variables_dict["encoder_b2"]
+
+            layer_2 = tf.nn.sigmoid(tf.matmul(layer_1, encoder_w2) + encoder_b2)
+            return layer_2
+
+    def decoder(x, name="decoder"):
+        with tf.name_scope(name):
+            decoder_w1 = tf.Variable(tf.random_normal([128, 256]))
+            decoder_b1 = tf.Variable(tf.random_normal([256]))
+
+            layer_1 = tf.nn.sigmoid(tf.matmul(x, decoder_w1) + decoder_b1)
+
+            decoder_w2 = tf.Variable(
+                tf.random_normal([256, SEQUENCE_LENGTH * EMBEDDING_SIZE]))
+            decoder_b2 = tf.Variable(
+                tf.random_normal([SEQUENCE_LENGTH * EMBEDDING_SIZE]))
+
+            layer_2 = tf.nn.sigmoid(tf.matmul(layer_1, decoder_w2) + decoder_b2)
+            return layer_2
+
+    tf.reset_default_graph()
+    sess = tf.Session()
+
+    input_x = tf.placeholder(tf.float32, [None, SEQUENCE_LENGTH, EMBEDDING_SIZE], name="input_x")
+    flattened = tf.reshape(input_x,
+                           [-1, SEQUENCE_LENGTH * EMBEDDING_SIZE])
+
+    encoder_op = encoder(flattened)
+
+    tf.add_to_collection('encoder_op', encoder_op)
+
+    y_pred = decoder(encoder_op)
+
+    y_true = flattened
+
+    with tf.name_scope("xent"):
+        cosine = tf.div(tf.reduce_sum(tf.multiply(y_pred, y_true), 1),
+                        tf.multiply(tf.sqrt(tf.reduce_sum(tf.multiply(y_pred, y_pred), 1)),
+                                    tf.sqrt(tf.reduce_sum(tf.multiply(y_true, y_true), 1))))
+        xent = tf.reduce_sum(tf.subtract(tf.constant(1.0), cosine))
+        tf.summary.scalar("xent", xent)
+
+    with tf.name_scope("train"):
+        # train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(xent)
+        train_step = tf.train.RMSPropOptimizer(0.01).minimize(xent)
+
+    summ = tf.summary.merge_all()
+
+    sess.run(tf.global_variables_initializer())
+
+    for i in range(params["epochs"]):
+        print("epoll {}".format(i))
+        for data in _read_data(max_records=params["batch_size"]):
+            batch_data = feed_dict(data)
+            sess.run(train_step, feed_dict={input_x: batch_data})
+
+    sess.close()
+
+
+class TFTextTransformerTest(SparkDLTestCase):
+    def test_convertText(self):
+        input_col = "text"
+        output_col = "sentence_matrix"
+
+        documentDF = self.session.createDataFrame([
+            ("Hi I heard about Spark", 1),
+            ("I wish Java could use case classes", 0),
+            ("Logistic regression models are neat", 2)
+        ], ["text", "preds"])
+
+        # transform the text column into a sentence_matrix column that contains a 2-D array.
+        transformer = TFTextTransformer(
+            inputCol=input_col, outputCol=output_col, embeddingSize=100, sequenceLength=64)
+
+        df = transformer.transform(documentDF)
+        data = df.collect()
+        self.assertEquals(len(data), 3)
+        for row in data:
+            self.assertEqual(len(row[output_col]), 64)
+            self.assertEqual(len(row[output_col][0]), 100)
+
+
+class TFTextFileEstimatorTest(SparkDLTestCase):
+    def test_trainText(self):
+        input_col = "text"
+        output_col = "sentence_matrix"
+
+        documentDF = self.session.createDataFrame([
+            ("Hi I heard about Spark", 1),
+            ("I wish Java could use case classes", 0),
+            ("Logistic regression models are neat", 2)
+        ], ["text", "preds"])
+
+        # transform the text column into a sentence_matrix column that contains a 2-D array.
+        transformer = TFTextTransformer(
+            inputCol=input_col, outputCol=output_col, embeddingSize=100, sequenceLength=64)
+
+        df = transformer.transform(documentDF)
+        import tempfile
+        mock_kafka_file = tempfile.mkdtemp()
+        # create an estimator for training; map_fun contains the TensorFlow code
+        estimator = TFTextFileEstimator(inputCol="sentence_matrix", outputCol="sentence_matrix", labelCol="preds",
+                                        kafkaParam={"bootstrap_servers": ["127.0.0.1"], "topic": "test",
+                                                    "mock_kafka_file": mock_kafka_file,
+                                                    "group_id": "sdl_1", "test_mode": True},
+                                        fitParam=[{"epochs": 5, "batch_size": 64}, {"epochs": 5, "batch_size": 1}],
+                                        mapFnParam=map_fun)
+        estimator.fit(df).collect()
+        shutil.rmtree(mock_kafka_file)
+
+
+class MockKafkaServerTest(SparkDLTestCase):
+    def test_mockKafkaServerProduce(self):
+        import tempfile
+        mock_kafka_file = tempfile.mkdtemp()
+        dataset = self.session.createDataFrame([
+            ("Hi I heard about Spark", 1),
+            ("I wish Java could use case classes", 0),
+            ("Logistic regression models are neat", 2)
+        ], ["text", "preds"])
+
+        def _write_data():
+            def _write_partition(index, d_iter):
+                producer = KafkaMockServer(index, mock_kafka_file)
+                try:
+                    for d in d_iter:
+                        producer.send("", pickle.dumps(d))
+                    producer.send("", pickle.dumps("_stop_"))
+                    producer.flush()
+                finally:
+                    producer.close()
+                return []
+
+            dataset.rdd.mapPartitionsWithIndex(_write_partition).count()
+
+        _write_data()
+
+        def _consume():
+            consumer = KafkaMockServer(0, mock_kafka_file)
+            stop_count = 0
+            while True:
+                messages = consumer.poll(timeout_ms=1000, max_records=64)
+                group_msgs = []
+                for tp, records in messages.items():
+                    for record in records:
+                        try:
+                            msg_value = pickle.loads(record.value)
+                            print(msg_value)
+                            if msg_value == "_stop_":
+                                stop_count += 1
+                            else:
+                                group_msgs.append(msg_value)
+                        except:
+                            pass
+                if stop_count >= 8:
+                    break
+            self.assertEquals(stop_count, 8)
+
+        # Run the consumer on two threads; each should drain all messages from the mock server.
+        t = threading.Thread(target=_consume)
+        t.start()
+        t2 = threading.Thread(target=_consume)
+        t2.start()
+        import time
+        time.sleep(10)
+        shutil.rmtree(mock_kafka_file)