Skip to content

Commit adfef58

Browse files
committed
gpu + cpu
1 parent f7fccf9 commit adfef58

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

carte_ai/src/carte_model.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,24 @@
99
def _carte_calculate_attention(
1010
edge_index: Tensor, query: Tensor, key: Tensor, value: Tensor
1111
):
12+
## Fix to work on cpu and gpu provided by Ayoub Kachkach
1213
# Calculate the scaled-dot product attention
1314
attention = torch.sum(torch.mul(query[edge_index[0], :], key), dim=1)
1415
attention = attention / math.sqrt(query.size(1))
1516
attention = softmax(attention, edge_index[0])
16-
17+
18+
# Ensure `attention` and `value` have the same dtype
19+
attention = attention.to(value.dtype)
20+
1721
# Generate the output
1822
src = torch.mul(attention, value.t()).t()
19-
23+
24+
# Ensure `src` and `query` have the same dtype
25+
src = src.to(query.dtype)
26+
2027
# Use torch.index_add_ to replace scatter function
2128
output = torch.zeros_like(query).index_add_(0, edge_index[0], src)
22-
29+
2330
return output, attention
2431

2532

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "carte-ai"
7-
version = "0.0.23"
7+
version = "0.0.25"
88
description = "CARTE-AI: Context Aware Representation of Table Entries for AI"
99
readme = { file = "README.md", content-type = "text/markdown" }
1010
requires-python = ">=3.10.12"

0 commit comments

Comments
 (0)