Skip to content

Commit

Permalink
[red-knot] Add support for unpacking union types (astral-sh#15052)
Browse files Browse the repository at this point in the history
## Summary

Refer:
astral-sh#13773 (comment)

This PR adds support for unpacking union types. 

Unpacking a union type requires us to first distribute the types for all
the targets that are involved in an unpacking. For example, if there are
two targets and a union type that needs to be unpacked, each target will
get a type from each element in the union type.

For example, if the type is `tuple[int, int] | tuple[int, str]` and the
target has two elements `(a, b)`, then
* The type of `a` will be a union of `int` and `int` which are at index
0 in the first and second tuple respectively which resolves to an `int`.
* Similarly, the type of `b` will be a union of `int` and `str` which
are at index 1 in the first and second tuple respectively which will be
`int | str`.

### Refactors

There are couple of refactors that are added in this PR:
* Add a `debug_assertion` to validate that the unpack target is a list
or a tuple
* Add a separate method to handle starred expression

## Test Plan

Update `unpacking.md` with additional test cases that uses union types.
This is done using parameter type hints style.
  • Loading branch information
dhruvmanila authored Dec 20, 2024
1 parent 089a98e commit d47fba1
Show file tree
Hide file tree
Showing 4 changed files with 310 additions and 81 deletions.
166 changes: 166 additions & 0 deletions crates/red_knot_python_semantic/resources/mdtest/unpacking.md
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,169 @@ reveal_type(b) # revealed: Unknown
reveal_type(a) # revealed: LiteralString
reveal_type(b) # revealed: LiteralString
```

## Union

### Same types

Union of two tuples of equal length and each element is of the same type.

```py
def _(arg: tuple[int, int] | tuple[int, int]):
(a, b) = arg
reveal_type(a) # revealed: int
reveal_type(b) # revealed: int
```

### Mixed types (1)

Union of two tuples of equal length and one element differs in its type.

```py
def _(arg: tuple[int, int] | tuple[int, str]):
a, b = arg
reveal_type(a) # revealed: int
reveal_type(b) # revealed: int | str
```

### Mixed types (2)

Union of two tuples of equal length and both the element types are different.

```py
def _(arg: tuple[int, str] | tuple[str, int]):
a, b = arg
reveal_type(a) # revealed: int | str
reveal_type(b) # revealed: str | int
```

### Mixed types (3)

Union of three tuples of equal length and various combination of element types:

1. All same types
1. One different type
1. All different types

```py
def _(arg: tuple[int, int, int] | tuple[int, str, bytes] | tuple[int, int, str]):
a, b, c = arg
reveal_type(a) # revealed: int
reveal_type(b) # revealed: int | str
reveal_type(c) # revealed: int | bytes | str
```

### Nested

```py
def _(arg: tuple[int, tuple[str, bytes]] | tuple[tuple[int, bytes], Literal["ab"]]):
a, (b, c) = arg
reveal_type(a) # revealed: int | tuple[int, bytes]
reveal_type(b) # revealed: str
reveal_type(c) # revealed: bytes | LiteralString
```

### Starred expression

```py
def _(arg: tuple[int, bytes, int] | tuple[int, int, str, int, bytes]):
a, *b, c = arg
reveal_type(a) # revealed: int
# TODO: Should be `list[bytes | int | str]`
reveal_type(b) # revealed: @Todo(starred unpacking)
reveal_type(c) # revealed: int | bytes
```

### Size mismatch (1)

```py
def _(arg: tuple[int, bytes, int] | tuple[int, int, str, int, bytes]):
# TODO: Add diagnostic (too many values to unpack)
a, b = arg
reveal_type(a) # revealed: int
reveal_type(b) # revealed: bytes | int
```

### Size mismatch (2)

```py
def _(arg: tuple[int, bytes] | tuple[int, str]):
# TODO: Add diagnostic (there aren't enough values to unpack)
a, b, c = arg
reveal_type(a) # revealed: int
reveal_type(b) # revealed: bytes | str
reveal_type(c) # revealed: Unknown
```

### Same literal types

```py
def _(flag: bool):
if flag:
value = (1, 2)
else:
value = (3, 4)

a, b = value
reveal_type(a) # revealed: Literal[1, 3]
reveal_type(b) # revealed: Literal[2, 4]
```

### Mixed literal types

```py
def _(flag: bool):
if flag:
value = (1, 2)
else:
value = ("a", "b")

a, b = value
reveal_type(a) # revealed: Literal[1] | Literal["a"]
reveal_type(b) # revealed: Literal[2] | Literal["b"]
```

### Typing literal

```py
from typing import Literal

def _(arg: tuple[int, int] | Literal["ab"]):
a, b = arg
reveal_type(a) # revealed: int | LiteralString
reveal_type(b) # revealed: int | LiteralString
```

### Custom iterator (1)

```py
class Iterator:
def __next__(self) -> tuple[int, int] | tuple[int, str]:
return (1, 2)

class Iterable:
def __iter__(self) -> Iterator:
return Iterator()

((a, b), c) = Iterable()
reveal_type(a) # revealed: int
reveal_type(b) # revealed: int | str
reveal_type(c) # revealed: tuple[int, int] | tuple[int, str]
```

### Custom iterator (2)

```py
class Iterator:
def __next__(self) -> bytes:
return b""

class Iterable:
def __iter__(self) -> Iterator:
return Iterator()

def _(arg: tuple[int, str] | Iterable):
a, b = arg
reveal_type(a) # revealed: int | bytes
reveal_type(b) # revealed: str | bytes
```
7 changes: 7 additions & 0 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,13 @@ impl<'db> Type<'db> {
.expect("Expected a Type::KnownInstance variant")
}

pub const fn into_tuple(self) -> Option<TupleType<'db>> {
match self {
Type::Tuple(tuple_type) => Some(tuple_type),
_ => None,
}
}

pub const fn is_boolean_literal(&self) -> bool {
matches!(self, Type::BooleanLiteral(..))
}
Expand Down
4 changes: 2 additions & 2 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ fn infer_unpack_types<'db>(db: &'db dyn Db, unpack: Unpack<'db>) -> UnpackResult
let result = infer_expression_types(db, value);
let value_ty = result.expression_ty(value.node_ref(db).scoped_expression_id(db, scope));

let mut unpacker = Unpacker::new(db, file);
unpacker.unpack(unpack.target(db), value_ty, scope);
let mut unpacker = Unpacker::new(db, scope);
unpacker.unpack(unpack.target(db), value_ty);
unpacker.finish()
}

Expand Down
Loading

0 comments on commit d47fba1

Please sign in to comment.