diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8206b8d4..096c480d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,10 @@
 
 ## Next Release
 
+### Updates
+- Enable cell-specific languages to run against the same Spark session via the new `%%spark -l`/`--language` argument (for polyglot notebooks).
+
+
 ## 0.22.0
 
 ### Updates
diff --git a/examples/Multilang Notebook.ipynb b/examples/Multilang Notebook.ipynb
new file mode 100644
index 00000000..9dd08445
--- /dev/null
+++ b/examples/Multilang Notebook.ipynb
@@ -0,0 +1,311 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "eb22e89b-bc39-424a-a205-68479ace406c",
+   "metadata": {},
+   "source": [
+    "# Mixing Spark and PySpark cells in the same Notebook\n",
+    "\n",
+    "Sparkmagic enables the use of Python, Scala, and R code cells within the same notebook and SparkSession, allowing you to mix UDFs from different languages in a single DataFrame and to leverage any Spark library (whether written for Python, Scala, or R) from the language of your choice.\n",
+    "\n",
+    "**Note:** This notebook illustrates the use of Python and Scala, but the process for R is the same.\n",
+    "\n",
+    "**Note:** Remember to specify `spark.jars` or `spark.submit.pyFiles` (as needed) when you want to import external packages into Spark."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "964d3636-112c-450b-a5e2-b5f95902fe9f",
+   "metadata": {},
+   "source": [
+    "### Sharing UDFs\n",
+    "\n",
+    "Custom logic, such as UDFs, can be used from any language. This example shows how you can call Python functions from Scala."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "afffbedb-1e2b-4864-95ef-9624ca32ea10",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "df: org.apache.spark.sql.Dataset[Long] = [id: bigint]\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%spark -l scala\n",
+    "\n",
+    "val df = spark.range(10)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "49f22be2-3fb0-4328-b4dd-5bb0b962e203",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      ""
+     ]
+    }
+   ],
+   "source": [
+    "%%spark -l python\n",
+    "\n",
+    "\n",
+    "def plus_one(x):\n",
+    "    return x + 1\n",
+    "\n",
+    "\n",
+    "def real_world_function(x, y, z):\n",
+    "    # import pandas, networkx, scikit ...\n",
+    "    pass\n",
+    "\n",
+    "spark.udf.register(\"plus_one\", plus_one, returnType=\"int\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "59793c3c-14b4-4d55-9b6c-b12a7284ca1b",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "df2: 
org.apache.spark.sql.DataFrame = [id: bigint, col2: int]\n", + "== Physical Plan ==\n", + "*(2) Project [id#21L, pythonUDF0#29 AS col2#23]\n", + "+- BatchEvalPython [plus_one(id#21L)], [id#21L, pythonUDF0#29]\n", + " +- *(1) Range (0, 10, step=1, splits=4)\n", + "+---+----+\n", + "| id|col2|\n", + "+---+----+\n", + "| 0| 1|\n", + "| 1| 2|\n", + "| 2| 3|\n", + "| 3| 4|\n", + "| 4| 5|\n", + "| 5| 6|\n", + "| 6| 7|\n", + "| 7| 8|\n", + "| 8| 9|\n", + "| 9| 10|\n", + "+---+----+\n", + "\n" + ] + } + ], + "source": [ + "%%spark -l scala\n", + "\n", + "val df2 = df.withColumn(\"col2\", callUDF(\"plus_one\", $\"id\"))\n", + "df2.explain()\n", + "df2.show()" + ] + }, + { + "cell_type": "markdown", + "id": "8f8dd5e9-4d29-437d-90fe-75445cb9072f", + "metadata": {}, + "source": [ + "Logic can also be shared via Views:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "12e93be3-bacb-4b1a-9216-5710ff24c5ce", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%%spark -l python\n", + "\n", + "df = spark.range(10)\n", + "df.createTempView(\"for_scala\")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "bc06e2e9-315c-4495-908a-9e0c623e17dc", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "otherDF: org.apache.spark.sql.DataFrame = [id: bigint, scala_col: bigint]\n" + ] + } + ], + "source": [ + "%%spark -l scala\n", + "\n", + "val otherDF = spark.range(10).withColumn(\"scala_col\", $\"id\" * 100)\n", + "\n", + "spark.table(\"for_scala\").join(otherDF, Seq(\"id\")).createOrReplaceTempView(\"for_python\")" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "2ee0abb1-c019-4eb0-af6e-9f9653ec3cf3", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "== Physical Plan ==\n", + "*(2) Project [id#38L, scala_col#75L]\n", + "+- *(2) BroadcastHashJoin [id#38L], [id#73L], Inner, BuildLeft\n", + " :- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, false]))\n", + " : +- *(1) Range (0, 10, step=1, splits=4)\n", + " +- *(2) Project [id#73L, (id#73L * 100) AS scala_col#75L]\n", + " +- *(2) Range (0, 10, step=1, splits=4)\n", + "+---+---------+\n", + "| id|scala_col|\n", + "+---+---------+\n", + "| 0| 0|\n", + "| 1| 100|\n", + "| 2| 200|\n", + "| 3| 300|\n", + "| 4| 400|\n", + "| 5| 500|\n", + "| 6| 600|\n", + "| 7| 700|\n", + "| 8| 800|\n", + "| 9| 900|\n", + "+---+---------+" + ] + } + ], + "source": [ + "%%spark -l python\n", + "\n", + "spark.table(\"for_python\").explain()\n", + 
"spark.table(\"for_python\").show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21afe6c9-681c-41c3-801c-0934bf43acac", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Spark", + "language": "scala", + "name": "sparkkernel" + }, + "language_info": { + "codemirror_mode": "text/x-scala", + "file_extension": ".sc", + "mimetype": "text/x-scala", + "name": "scala", + "pygments_lexer": "scala" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/sparkmagic/sparkmagic/kernels/kernelmagics.py b/sparkmagic/sparkmagic/kernels/kernelmagics.py index 1269947a..73e4cb64 100644 --- a/sparkmagic/sparkmagic/kernels/kernelmagics.py +++ b/sparkmagic/sparkmagic/kernels/kernelmagics.py @@ -359,6 +359,13 @@ def configure(self, line, cell="", local_ns=None): help="Whether to automatically coerce the types (default, pass True if being explicit) " "of the dataframe or not (pass False)", ) + @argument( + "-l", + "--language", + type=str, + default=None, + help=f"Specific language for the current cell (supported: {','.join(LANGS_SUPPORTED)})", + ) @wrap_unexpected_exceptions @handle_expected_exceptions def spark(self, line, cell="", local_ns=None): @@ -377,6 +384,7 @@ def spark(self, line, cell="", local_ns=None): args.samplefraction, None, coerce, + language=args.language, ) @cell_magic diff --git a/sparkmagic/sparkmagic/livyclientlib/command.py b/sparkmagic/sparkmagic/livyclientlib/command.py index cd4e67cc..62a05764 100644 --- a/sparkmagic/sparkmagic/livyclientlib/command.py +++ b/sparkmagic/sparkmagic/livyclientlib/command.py @@ -36,6 +36,7 @@ def __init__(self, code, spark_events=None): if spark_events is None: spark_events = SparkEvents() self._spark_events = spark_events + self.kind = None def __repr__(self): return "Command({}, ...)".format(repr(self.code)) @@ -46,14 +47,22 @@ def __eq__(self, other): def __ne__(self, other): return not self == other + def set_language(self, lang): + if lang is not None: + self.kind = conf.get_livy_kind(lang) + return self + def execute(self, session): + kind = self.kind or session.kind self._spark_events.emit_statement_execution_start_event( - session.guid, session.kind, session.id, self.guid + session.guid, kind, session.id, self.guid ) statement_id = -1 try: session.wait_for_idle() data = {"code": self.code} + if self.kind: + data["kind"] = kind response = session.http_client.post_statement(session.id, data) statement_id = response["id"] output = self._get_statement_output(session, statement_id) diff --git a/sparkmagic/sparkmagic/magics/sparkmagicsbase.py b/sparkmagic/sparkmagic/magics/sparkmagicsbase.py index 93fcc6f6..5a298d5d 100644 --- a/sparkmagic/sparkmagic/magics/sparkmagicsbase.py +++ b/sparkmagic/sparkmagic/magics/sparkmagicsbase.py @@ -124,6 +124,7 @@ def execute_spark( session_name, coerce, output_handler=None, + language=None, ): output_handler = output_handler or SparkOutputHandler( html=self.ipython_display.html, @@ -131,8 +132,9 @@ def execute_spark( default=self.ipython_display.display, ) + command = Command(cell).set_language(language) (success, out, mimetype) = self.spark_controller.run_command( - Command(cell), session_name + command, session_name ) if not success: if conf.shutdown_session_on_spark_statement_errors():