Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,8 @@ using Reactant
Reactant.set_default_backend("tpu")
```

```julia [Tenstorrent (Experimental)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While updating this, do we want to mention AMD GPUs as well?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you set the backend? Is it just gpu like for Nvidia? Also, does it actually work? At least this backend can do a matmul 😛

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In principle it should just also be GPU

using Reactant
Reactant.set_default_backend("tt")
```
:::
1 change: 1 addition & 0 deletions src/accelerators/Accelerators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ module Accelerators

include("TPU.jl")
include("Metal.jl")
include("TT.jl")

end
82 changes: 82 additions & 0 deletions src/accelerators/TT.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
module TT

using Reactant: Reactant
using Scratch: @get_scratch!
using Downloads: Downloads
using p7zip_jll: p7zip

# Directory that holds the TT PJRT plugin. Starts as `nothing` and is assigned by
# `setup_tt_pjrt_plugin!`.
const tt_pjrt_plugin_dir = Ref{Union{Nothing,String}}(nothing)
# File name of the TT PJRT plugin shared library; it is joined onto the plugin directory
# to form the full plugin path.
const tt_pjrt_plugin_name = Ref{String}("pjrt_plugin_tt.so")

# Module initializer: on Linux, when not precompiling and `has_tt()` is true, set up the
# TT PJRT plugin. On every other path this is a no-op.
function __init__()
    @static if Sys.islinux()
        # Skip the setup while Reactant is being precompiled.
        Reactant.precompiling() && return nothing
        # Nothing to do when no Tenstorrent device is detected (and init is not forced).
        has_tt() || return nothing
        setup_tt_pjrt_plugin!()
    end
    return nothing
end

"""
    force_tt_init()

Return `true` when the environment variable `REACTANT_FORCE_TT_INIT` is set (to any
value), `false` otherwise.
"""
function force_tt_init()
    # Only the presence of the variable matters, not its contents.
    return get(ENV, "REACTANT_FORCE_TT_INIT", nothing) !== nothing
end

"""
    has_tt()

Return `true` if TT initialization is forced via `force_tt_init()`, or if the directory
`/dev/tenstorrent` exists and contains at least one entry. Return `false` otherwise.
"""
function has_tt()
    # An explicit override short-circuits the device detection entirely.
    force_tt_init() && return true

    # There are two ways to find out whether Tenstorrent devices are present:
    #
    # * look for devices in `/dev/tenstorrent`, or
    # * look for devices in `/sys/bus/pci/devices` with `vendor` equal to `0x1e52`, e.g.
    #   any(readchomp(joinpath(dir, "vendor")) == "0x1e52" for dir in readdir("/sys/bus/pci/devices"; join=true))
    #
    # The first option is simpler and sufficient for the current purposes.
    tt_device_dir = "/dev/tenstorrent"
    isdir(tt_device_dir) || return false
    return !isempty(readdir(tt_device_dir))
end

"""
    setup_tt_pjrt_plugin!()

Select the directory for the TT PJRT plugin, store it in `tt_pjrt_plugin_dir`, and make
sure the plugin is available there.

The directory named by the environment variable `TT_PJRT_PLUGIN_DIR` is used when that
variable is set and the path exists. Otherwise the `pjrt_plugin_tt` scratch space is used.
Returns `nothing`.
"""
function setup_tt_pjrt_plugin!()
    env_dir = get(ENV, "TT_PJRT_PLUGIN_DIR", nothing)
    # The env var only wins when it points at an existing path.
    use_env_dir = env_dir !== nothing && ispath(env_dir)
    tt_pjrt_plugin_dir[] = use_env_dir ? env_dir : @get_scratch!("pjrt_plugin_tt")
    download_tt_pjrt_plugin_if_needed(tt_pjrt_plugin_dir[])
    return nothing
end

"""
    get_tt_pjrt_plugin_dir()

Return the current value of `tt_pjrt_plugin_dir`: the plugin directory as a `String`, or
`nothing` if it has not been assigned.
"""
function get_tt_pjrt_plugin_dir()
    return tt_pjrt_plugin_dir[]
end

"""
    get_tt_pjrt_plugin_path()

Return the path of the TT PJRT plugin library: the plugin directory joined with the plugin
file name.
"""
function get_tt_pjrt_plugin_path()
    plugin_dir = get_tt_pjrt_plugin_dir()
    plugin_file = tt_pjrt_plugin_name[]
    return joinpath(plugin_dir, plugin_file)
end

"""
    download_tt_pjrt_plugin_if_needed(dir=nothing)

Make sure the TT PJRT plugin library exists inside `dir`, downloading it if necessary.

When `dir` is `nothing`, the directory returned by `get_tt_pjrt_plugin_dir()` is used.
If the plugin file is already present nothing is done. Otherwise the plugin wheel is
downloaded into a temporary directory, extracted with `p7zip`, and its `pjrt_plugin_tt`
directory is moved to `dir` (replacing `dir` if it already exists, because the move uses
`force=true`).

Returns `nothing`.

# Throws
- `ErrorException` if no plugin directory is available, if the architecture is not
  `x86_64`, or if the plugin file is still missing after the installation.
"""
function download_tt_pjrt_plugin_if_needed(dir=nothing)
    dir === nothing && (dir = get_tt_pjrt_plugin_dir())
    # Explicit check instead of `@assert`: assertions are documented as possibly disabled,
    # and this guards real runtime state rather than an internal invariant.
    dir === nothing && error("tt_pjrt_plugin_dir is not set!")

    tt_pjrt_plugin_path = joinpath(dir, tt_pjrt_plugin_name[])
    if isfile(tt_pjrt_plugin_path)
        @debug "TT PJRT plugin already found in '$(tt_pjrt_plugin_path)', nothing to do"
        return nothing
    end

    @debug "Will install the TT PJRT plugin to '$(tt_pjrt_plugin_path)'"
    mktempdir() do tmp_dir
        # Index at https://pypi.eng.aws.tenstorrent.com/pjrt-plugin-tt/
        zip_file_path = joinpath(tmp_dir, "pjrt-plugin-tt.zip")
        wheel_url = if Sys.ARCH === :x86_64
            "https://pypi.eng.aws.tenstorrent.com/pjrt-plugin-tt/pjrt_plugin_tt-0.6.0.dev20251202-cp311-cp311-linux_x86_64.whl"
        else
            error("Unsupported architecture for TT PJRT plugin: $(Sys.ARCH)")
        end
        @debug "Downloading TT PJRT plugin from '$(wheel_url)'"
        Downloads.download(wheel_url, zip_file_path)
        # A wheel is a zip archive; extract it next to the download, discarding stdout.
        run(pipeline(`$(p7zip()) x -tzip -o$(tmp_dir) -- $(zip_file_path)`, devnull))
        # Exactly one `*.data` entry is expected in the extracted wheel.
        data_dir = only(filter!(endswith(".data"), readdir(tmp_dir; join=true)))
        # We need to move the entire `pjrt_plugin_tt` directory to the destination.
        mv(joinpath(data_dir, "purelib", "pjrt_plugin_tt"), dir; force=true)
    end

    # Post-condition: the wheel layout must have produced the expected plugin file.
    isfile(tt_pjrt_plugin_path) || error(
        "TT PJRT plugin not found at '$(tt_pjrt_plugin_path)' after installation"
    )
    return nothing
end

end # module TT
13 changes: 13 additions & 0 deletions src/xla/IFRT/Client.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,14 @@ const cpu_client_count = Ref(0)
# Per-backend client counters for the IFRT runtime; each one is paired with its client
# constructor name in the loop that follows.
# NOTE(review): presumably these track how many clients of each kind currently exist —
# confirm against the `checkcount` logic.
const cuda_client_count = Ref(0)
const tpu_client_count = Ref(0)
const metal_client_count = Ref(0)
const tt_client_count = Ref(0)

for (backend, counter) in (
(:CPUClient, :cpu_client_count),
(:CUDAClient, :cuda_client_count),
(:TPUClient, :tpu_client_count),
(:MetalClient, :metal_client_count),
(:TTClient, :tt_client_count),
)
main_fn = Symbol(:MakeIFRTPJRT, backend)
@eval function $(backend)(args...; checkcount::Bool=true, kwargs...)
Expand Down Expand Up @@ -219,6 +221,17 @@ function MakeIFRTPJRTMetalClient(;
)
end

"""
    MakeIFRTPJRTTTClient(; tt_pjrt_plugin_path, node_id=0, num_nodes=1,
                         distributed_runtime_client=nothing)

Create a client for the Tenstorrent backend by forwarding to
`MakeIFRTPJRTClientViaPluginAPI` with the plugin library at `tt_pjrt_plugin_path`, device
type `"tt"` and client name `"TT"`. The remaining keyword arguments are passed through
unchanged.
"""
function MakeIFRTPJRTTTClient(;
    tt_pjrt_plugin_path::String,
    node_id::Integer=0,
    num_nodes::Integer=1,
    distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
)
    device_type = "tt"
    client_name = "TT"
    return MakeIFRTPJRTClientViaPluginAPI(
        tt_pjrt_plugin_path,
        device_type,
        client_name;
        node_id=node_id,
        num_nodes=num_nodes,
        distributed_runtime_client=distributed_runtime_client,
    )
end

function MakeIFRTPJRTClientViaPluginAPI(
library_path::String,
device_type::String,
Expand Down
16 changes: 16 additions & 0 deletions src/xla/PJRT/Client.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,14 @@ const cpu_client_count = Ref(0)
# Per-backend client counters for the PJRT runtime; each one is paired with its client
# constructor name in the loop that follows.
# NOTE(review): presumably these track how many clients of each kind currently exist —
# confirm against the `checkcount` logic.
const cuda_client_count = Ref(0)
const tpu_client_count = Ref(0)
const metal_client_count = Ref(0)
const tt_client_count = Ref(0)

for (backend, counter) in (
(:CPUClient, :cpu_client_count),
(:CUDAClient, :cuda_client_count),
(:TPUClient, :tpu_client_count),
(:MetalClient, :metal_client_count),
(:TTClient, :tt_client_count),
)
main_fn = Symbol(:Make, backend)
@eval function $(backend)(args...; checkcount::Bool=true, kwargs...)
Expand Down Expand Up @@ -207,6 +209,20 @@ function MakeMetalClient(;
return MakeClientUsingPluginAPI(metal_pjrt_plugin_path, "metal", "METAL")
end

"""
    MakeTTClient(; tt_pjrt_plugin_path, node_id=0, num_nodes=1,
                 distributed_runtime_client=nothing)

Create a PJRT client for the Tenstorrent backend from the plugin library at
`tt_pjrt_plugin_path`, using device type `"tt"` and client name `"TT"`.

The distributed keyword arguments exist for signature parity only; anything other than
their defaults trips an assertion.
"""
function MakeTTClient(;
    tt_pjrt_plugin_path::String,
    node_id::Integer=0,
    num_nodes::Integer=1,
    distributed_runtime_client::Union{Nothing,XLA.DistributedRuntimeClient}=nothing,
)
    # The plugin-API call below receives none of the distributed options, so only their
    # default values are accepted here.
    @assert iszero(node_id) "`PJRT.MakeTTClient` does not support node_id"
    @assert isone(num_nodes) "`PJRT.MakeTTClient` does not support num_nodes > 1"
    @assert isnothing(distributed_runtime_client) "`PJRT.MakeTTClient` does not support distributed_runtime_client"

    return MakeClientUsingPluginAPI(tt_pjrt_plugin_path, "tt", "TT")
end

function MakeClientUsingPluginAPI(
library_path::String, device_type::String, client_name::String=uppercase(device_type)
)
Expand Down
35 changes: 35 additions & 0 deletions src/xla/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,41 @@ for runtime in (:PJRT, :IFRT)
catch e
println(stdout, e)
end
elseif Accelerators.TT.has_tt()
@debug "TT accelerator detected, setting it up"
try
if was_initialized && haskey(state.clients, "tt")
free_client(state.clients["tt"])
$(runtime).tt_client_count[] -= 1
end
# The env var `TT_METAL_RUNTIME_ROOT` must be set before creating the client.
tt_metal_runtime_root = get(ENV, "TT_METAL_RUNTIME_ROOT", nothing)
if isnothing(tt_metal_runtime_root)
tt_metal_path_in_wheel = joinpath(
dirname(Accelerators.TT.get_tt_pjrt_plugin_path()),
"tt-metal",
)
if ispath(tt_metal_path_in_wheel)
@debug "Setting environment variable 'TT_METAL_RUNTIME_ROOT' to '$(tt_metal_path_in_wheel)'"
ENV["TT_METAL_RUNTIME_ROOT"] = tt_metal_path_in_wheel
else
error(
"`TT_METAL_RUNTIME_ROOT` environment variable not set and we could not automatically determine it",
)
end
else
@debug "Environment variable 'TT_METAL_RUNTIME_ROOT' already set to to '$(tt_metal_runtime_root)'"
end

tt = $(runtime).TTClient(;
tt_pjrt_plugin_path=Accelerators.TT.get_tt_pjrt_plugin_path(),
common_kwargs...,
)
state.clients["tt"] = tt
state.default_client = tt
catch e
println(stdout, e)
end
elseif Reactant_jll.host_platform.tags["gpu"] != "none"
try
if was_initialized && haskey(state.clients, "cuda")
Expand Down
Loading