Skip to content
30 changes: 24 additions & 6 deletions azure/functions/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
import datetime
import json
import re
import sys

from typing import Dict, Optional, Union, Tuple, Mapping, Any
if sys.version_info >= (3, 9):
from typing import get_origin, get_args
else:
from ._thirdparty.typing_inspect import get_origin, get_args

from ._thirdparty import typing_inspect
from ._utils import (
Expand All @@ -16,16 +22,28 @@


def is_iterable_type_annotation(annotation: object, pytype: object) -> bool:
is_iterable_anno = (
typing_inspect.is_generic_type(annotation)
and issubclass(typing_inspect.get_origin(annotation),
collections.abc.Iterable)
)
"""Since python 3.9, standard collection types are supported in type hint.
origin is the unsubscripted version of a type (eg. list, union, etc.).
If origin is not None, then the type annotation is a builtin or part of
the collections class.
"""
origin = get_origin(annotation)
if sys.version_info >= (3, 9):
is_iterable_anno = (origin is not None
and issubclass(origin, collections.abc.Iterable))
else:
is_iterable_anno = (
typing_inspect.is_generic_type(annotation)
and origin is not None
and issubclass(origin, collections.abc.Iterable)
)

if not is_iterable_anno:
return False

args = typing_inspect.get_args(annotation)
args = get_args(annotation)

if not args:
return False

Expand Down
6 changes: 6 additions & 0 deletions tests/test_eventgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Licensed under the MIT License.

from datetime import datetime
import sys
import unittest
from typing import List

Expand All @@ -21,6 +22,11 @@ def test_eventgrid_input_type(self):
def test_eventgrid_output_type(self):
check_output_type = azf_event_grid.EventGridEventOutConverter.\
check_output_type_annotation

if sys.version_info >= (3, 9):
self.assertTrue(check_output_type(list[func.EventGridOutputEvent]))
self.assertTrue(check_output_type(list[str]))

self.assertTrue(check_output_type(func.EventGridOutputEvent))
self.assertTrue(check_output_type(List[func.EventGridOutputEvent]))
self.assertTrue(check_output_type(str))
Expand Down
8 changes: 8 additions & 0 deletions tests/test_eventhub.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import sys
from typing import List, Mapping
import unittest
import json
Expand All @@ -21,6 +22,9 @@ def test_eventhub_input_type(self):
check_input_type = (
azf_eh.EventHubConverter.check_input_type_annotation
)
if sys.version_info >= (3, 9):
self.assertTrue(check_input_type(list[func.EventHubEvent]))

self.assertTrue(check_input_type(func.EventHubEvent))
self.assertTrue(check_input_type(List[func.EventHubEvent]))
self.assertFalse(check_input_type(str))
Expand All @@ -31,6 +35,10 @@ def test_eventhub_output_type(self):
check_output_type = (
azf_eh.EventHubTriggerConverter.check_output_type_annotation
)

if sys.version_info >= (3, 9):
self.assertTrue(check_output_type(list[str]))

self.assertTrue(check_output_type(bytes))
self.assertTrue(check_output_type(str))
self.assertTrue(check_output_type(List[str]))
Expand Down
9 changes: 9 additions & 0 deletions tests/test_kafka.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import sys
from typing import List
import unittest
import json
Expand Down Expand Up @@ -34,6 +35,10 @@ def test_kafka_input_type(self):
check_input_type = (
azf_ka.KafkaConverter.check_input_type_annotation
)

if sys.version_info >= (3, 9):
self.assertTrue(check_input_type(list[func.KafkaEvent]))

self.assertTrue(check_input_type(func.KafkaEvent))
self.assertTrue(check_input_type(List[func.KafkaEvent]))
self.assertFalse(check_input_type(str))
Expand All @@ -44,6 +49,10 @@ def test_kafka_output_type(self):
check_output_type = (
azf_ka.KafkaTriggerConverter.check_output_type_annotation
)

if sys.version_info >= (3, 9):
self.assertTrue(check_output_type(list[str]))

self.assertTrue(check_output_type(bytes))
self.assertTrue(check_output_type(str))
self.assertTrue(check_output_type(List[str]))
Expand Down