2
2
3
3
"""Tests for `samgeo` package."""
4
4
5
-
5
+ import os
6
6
import unittest
7
7
8
8
from samgeo import samgeo
@@ -13,9 +13,59 @@ class TestSamgeo(unittest.TestCase):
13
13
14
14
def setUp (self ):
15
15
"""Set up test fixtures, if any."""
16
+ bbox = [- 122.1497 , 37.6311 , - 122.1203 , 37.6458 ]
17
+ image = "satellite.tif"
18
+ samgeo .tms_to_geotiff (
19
+ output = image , bbox = bbox , zoom = 15 , source = "Satellite" , overwrite = True
20
+ )
21
+ self .source = image
22
+
23
+ out_dir = os .path .join (os .path .expanduser ("~" ), "Downloads" )
24
+ checkpoint = os .path .join (out_dir , "sam_vit_h_4b8939.pth" )
25
+ self .checkpoint = checkpoint
26
+
27
+ sam = samgeo .SamGeo (
28
+ model_type = "vit_h" ,
29
+ checkpoint = checkpoint ,
30
+ sam_kwargs = None ,
31
+ )
32
+
33
+ self .sam = sam
16
34
17
35
def tearDown (self ):
18
36
"""Tear down test fixtures, if any."""
19
37
20
- def test_000_something (self ):
21
- """Test something."""
38
+ def test_generate (self ):
39
+ """Test the automatic generation of masks and annotations.
40
+ """
41
+ sam = self .sam
42
+ source = self .source
43
+
44
+ sam .generate (source , output = "masks.tif" , foreground = True , unique = True )
45
+ self .assertTrue (os .path .exists ("masks.tif" ))
46
+
47
+ sam .show_anns (axis = "off" , alpha = 1 , output = "annotations.tif" )
48
+ self .assertTrue (os .path .exists ("annotations.tif" ))
49
+
50
+ sam .tiff_to_vector ("masks.tif" , "masks.gpkg" )
51
+ self .assertTrue (os .path .exists ("masks.gpkg" ))
52
+
53
+
54
+ def test_predict (self ):
55
+ """Test the prediction of masks and annotations based on input prompts.
56
+ """
57
+ sam = samgeo .SamGeo (
58
+ model_type = "vit_h" ,
59
+ checkpoint = self .checkpoint ,
60
+ automatic = False ,
61
+ sam_kwargs = None ,
62
+ )
63
+
64
+ sam .set_image (self .source )
65
+ point_coords = [[- 122.1419 , 37.6383 ]]
66
+ sam .predict (point_coords , point_labels = 1 , point_crs = "EPSG:4326" , output = 'mask1.tif' )
67
+ self .assertTrue (os .path .exists ("mask1.tif" ))
68
+
69
+ point_coords = [[- 122.1464 , 37.6431 ], [- 122.1449 , 37.6415 ], [- 122.1451 , 37.6395 ]]
70
+ sam .predict (point_coords , point_labels = 1 , point_crs = "EPSG:4326" , output = 'mask2.tif' )
71
+ self .assertTrue (os .path .exists ("mask2.tif" ))
0 commit comments