Skip to content

Commit dda429c

Browse files
authored
Merge pull request #1059 from czgdp1807/dict09
Supporting tuples as keys in ``dict``
2 parents f4d6b11 + 72488f9 commit dda429c

File tree

5 files changed

+105
-3
lines changed

5 files changed

+105
-3
lines changed

integration_tests/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ RUN(NAME test_tuple_02 LABELS cpython llvm)
160160
RUN(NAME test_dict_01 LABELS cpython llvm)
161161
RUN(NAME test_dict_02 LABELS cpython llvm)
162162
RUN(NAME test_dict_03 LABELS cpython llvm)
163+
RUN(NAME test_dict_04 LABELS cpython llvm)
163164
RUN(NAME modules_01 LABELS cpython llvm)
164165
RUN(NAME modules_02 LABELS cpython llvm)
165166
RUN(NAME test_math LABELS cpython llvm)

integration_tests/test_dict_04.py

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from ltypes import i32, i64, f64
2+
from math import pi, sin, cos
3+
4+
def test_dict():
5+
terms2poly: dict[tuple[i32, i32], i64] = {}
6+
rtheta2coords: dict[tuple[i64, i64], tuple[f64, f64]] = {}
7+
i: i32
8+
n: i64
9+
size: i32 = 7000
10+
size1: i32
11+
theta: f64
12+
r: f64
13+
coords: tuple[f64, f64]
14+
eps: f64 = 1e-12
15+
16+
n = 0
17+
for i in range(1000, 1000 + size, 7):
18+
terms2poly[(i, i*i)] = int(i + i*i)
19+
20+
theta = float(n) * pi
21+
r = float(i)
22+
rtheta2coords[(int(i), n)] = (r * sin(theta), r * cos(theta))
23+
24+
n += int(1)
25+
26+
size1 = size/7
27+
n = 0
28+
for i in range(1000, 1000 + size//2, 7):
29+
assert terms2poly.pop((i, i*i)) == int(i + i*i)
30+
31+
theta = float(n) * pi
32+
r = float(i)
33+
coords = rtheta2coords.pop((int(i), n))
34+
assert abs(coords[0] - r * sin(theta)) <= eps
35+
assert abs(coords[1] - r * cos(theta)) <= eps
36+
37+
size1 = size1 - 1
38+
assert len(terms2poly) == size1
39+
n += int(1)
40+
41+
n = 0
42+
for i in range(1000, 1000 + size//2, 7):
43+
terms2poly[(i, i*i)] = int(1 + 2*i + i*i)
44+
45+
theta = float(n) * pi
46+
r = float(i)
47+
rtheta2coords[(int(i), n)] = (r * cos(theta), r * sin(theta))
48+
49+
n += int(1)
50+
51+
n = 0
52+
for i in range(1000, 1000 + size//2, 7):
53+
assert terms2poly[(i, i*i)] == (i + 1)*(i + 1)
54+
55+
theta = float(n) * pi
56+
r = float(i)
57+
assert abs(rtheta2coords[(int(i), n)][0] - r * cos(theta)) <= eps
58+
assert abs(rtheta2coords[(int(i), n)][1] - r * sin(theta)) <= eps
59+
60+
n += int(1)
61+
62+
n = 0
63+
for i in range(1000, 1000 + size, 7):
64+
terms2poly[(i, i*i)] = int(1 + 2*i + i*i)
65+
66+
theta = float(n) * pi
67+
r = float(i)
68+
rtheta2coords[(int(i), n)] = (r * cos(theta), r * sin(theta))
69+
n += int(1)
70+
71+
n = 0
72+
for i in range(1000, 1000 + size, 7):
73+
assert terms2poly[(i, i*i)] == (i + 1)*(i + 1)
74+
75+
theta = float(n) * pi
76+
r = float(i)
77+
assert abs(r**2 - rtheta2coords[(int(i), n)][0]**2 - r**2 * sin(theta)**2) <= eps
78+
n += int(1)
79+
80+
test_dict()

src/libasr/codegen/asr_to_llvm.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1369,6 +1369,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
13691369
ptr_loads = !LLVM::is_llvm_struct(dict_type->m_key_type);
13701370
this->visit_expr_wrapper(x.m_key, true);
13711371
llvm::Value *key = tmp;
1372+
ptr_loads = !LLVM::is_llvm_struct(dict_type->m_value_type);
13721373
this->visit_expr_wrapper(x.m_value, true);
13731374
llvm::Value *value = tmp;
13741375
ptr_loads = ptr_loads_copy;

src/libasr/codegen/llvm_utils.cpp

+21-2
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,12 @@ namespace LFortran {
883883
occupancy_ptr);
884884

885885
llvm::Value* linear_prob_happened = builder->CreateICmpNE(key_hash, pos);
886+
linear_prob_happened = builder->CreateOr(linear_prob_happened,
887+
builder->CreateICmpEQ(
888+
LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(key_mask, key_hash)),
889+
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 2)
890+
))
891+
);
886892
llvm::Value* set_max_2 = builder->CreateSelect(linear_prob_happened,
887893
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 2)),
888894
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));
@@ -941,7 +947,7 @@ namespace LFortran {
941947
}
942948

943949
llvm::Value* LLVMDict::get_key_hash(llvm::Value* capacity, llvm::Value* key,
944-
ASR::ttype_t* key_asr_type, llvm::Module& /*module*/) {
950+
ASR::ttype_t* key_asr_type, llvm::Module& module) {
945951
// Write specialised hash functions for intrinsic types
946952
// This is to avoid unnecessary calls to C-runtime and do
947953
// as much as possible in LLVM directly.
@@ -951,11 +957,12 @@ namespace LFortran {
951957
// We can update it later to do a better hash function
952958
// which produces lesser collisions.
953959

954-
return builder->CreateZExtOrTrunc(
960+
llvm::Value* int_hash = builder->CreateZExtOrTrunc(
955961
builder->CreateSRem(key,
956962
builder->CreateZExtOrTrunc(capacity, key->getType())),
957963
capacity->getType()
958964
);
965+
return int_hash;
959966
}
960967
case ASR::ttypeType::Character: {
961968
// Polynomial rolling hash function for strings
@@ -1022,6 +1029,18 @@ namespace LFortran {
10221029
hash = builder->CreateTrunc(hash, llvm::Type::getInt32Ty(context));
10231030
return builder->CreateSRem(hash, capacity);
10241031
}
1032+
case ASR::ttypeType::Tuple: {
1033+
llvm::Value* tuple_hash = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0));
1034+
ASR::Tuple_t* asr_tuple = ASR::down_cast<ASR::Tuple_t>(key_asr_type);
1035+
for( size_t i = 0; i < asr_tuple->n_type; i++ ) {
1036+
llvm::Value* llvm_tuple_i = llvm_utils->tuple_api->read_item(key, i,
1037+
LLVM::is_llvm_struct(asr_tuple->m_type[i]));
1038+
tuple_hash = builder->CreateAdd(tuple_hash, get_key_hash(capacity, llvm_tuple_i,
1039+
asr_tuple->m_type[i], module));
1040+
tuple_hash = builder->CreateSRem(tuple_hash, capacity);
1041+
}
1042+
return tuple_hash;
1043+
}
10251044
default: {
10261045
throw LCompilersException("Hashing " + ASRUtils::type_to_str_python(key_asr_type) +
10271046
" isn't implemented yet.");

src/lpython/semantics/python_ast_to_asr.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -2198,7 +2198,8 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
21982198
} else if (ASR::is_a<ASR::Dict_t>(*type)) {
21992199
throw SemanticError("unhashable type in dict: 'slice'", loc);
22002200
}
2201-
} else if(AST::is_a<AST::Tuple_t>(*m_slice)) {
2201+
} else if(AST::is_a<AST::Tuple_t>(*m_slice) &&
2202+
!ASR::is_a<ASR::Dict_t>(*type)) {
22022203
bool final_result = true;
22032204
AST::Tuple_t* indices = AST::down_cast<AST::Tuple_t>(m_slice);
22042205
for( size_t i = 0; i < indices->n_elts; i++ ) {

0 commit comments

Comments
 (0)