forked from viam-modules/torch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
27 changed files
with
1,226 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
MODULE_DIR=$(dirname $0) | ||
VIRTUAL_ENV=$MODULE_DIR/.venv | ||
PYTHON=$VIRTUAL_ENV/bin/python |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,82 @@ | ||
# torch | ||
# VIAM PYTORCH ML MODEL | ||
***in progress*** | ||
|
||
This is a [Viam module](https://docs.viam.com/extend/modular-resources/) providing a mlmodel service for PyTorch model | ||
|
||
## Getting started | ||
|
||
|
||
## Installation with `pip install` | ||
|
||
``` | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Configure your `mlmodel:torch-cpu` vision service | ||
|
||
> [!NOTE] | ||
> Before configuring your vision service, you must [create a robot](https://docs.viam.com/manage/fleet/robots/#add-a-new-robot). | ||
Navigate to the **Config** tab of your robot’s page in [the Viam app](https://app.viam.com/). Click on the **Services** subtab and click **Create service**. Select the `Vision` type, then select the `deepface_identification` model. Enter a name for your service and click **Create**. | ||
|
||
### Example | ||
|
||
|
||
```json | ||
{ | ||
"modules": [ | ||
{ | ||
"executable_path": "/Users/robinin/torch-infer/torch/run.sh", | ||
"name": "mymodel", | ||
"type": "local" | ||
} | ||
], | ||
"services": [ | ||
{ | ||
"name": "torch", | ||
"type": "mlmodel", | ||
"model": "viam:mlmodel:torch-cpu", | ||
"attributes": { | ||
"model_path": "examples/resnet_18/resnet-18.pt", | ||
"label_path": "examples/resnet_18/labels.txt", | ||
} | ||
} | ||
] | ||
} | ||
``` | ||
|
||
|
||
### Attributes description | ||
|
||
The following attributes are available to configure your module: | ||
|
||
|
||
| Name | Type | Inclusion | Default | Description | | ||
| ------------ | ------ | ------------ | ------- | --------------------------------- | | ||
| `model_path` | string | **Required** | | Path to **standalone** model file | | ||
| `label_path` | string | Optional | | Path to file with class labels. | | ||
|
||
|
||
|
||
|
||
# Methods | ||
## `infer()` | ||
``` | ||
infer(input_tensors: Dict[str, NDArray], *, timeout: Optional[float]) -> Dict[str, NDArray] | ||
``` | ||
|
||
### Example | ||
|
||
```python | ||
my_model = MLModelClient.from_robot(robot, "torch") | ||
input_image = np.array(Image.open(path_to_input_image), dtype=np.float32) | ||
input_image = np.transpose(input_image, (2, 0, 1)) # channel first | ||
input_image = np.expand_dims(input_image, axis=0) # batch dim | ||
input_tensor = dict() | ||
input_tensor["input"] = input_image | ||
output = await my_model.infer(input_tensor) | ||
print(f"output.shape is {output['output'].shape}") | ||
``` | ||
|
||
## `metadata()` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
{ | ||
"createdOn": "04/05/2021 21:35:21", | ||
"runtime": "python", | ||
"model": { | ||
"modelName": "resnet-18", | ||
"serializedFile": "resnet18-f37072fd.pth", | ||
"handler": "image_classifier", | ||
"modelFile": "model.py", | ||
"modelVersion": "1.0" | ||
}, | ||
"archiverVersion": "0.3.0" | ||
} |
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from torchvision.models.resnet import ResNet, BasicBlock | ||
|
||
|
||
class ImageClassifier(ResNet): | ||
def __init__(self): | ||
super(ImageClassifier, self).__init__(BasicBlock, [2, 2, 2, 2]) | ||
|
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
{ | ||
"createdOn": "04/05/2021 21:39:32", | ||
"runtime": "python", | ||
"model": { | ||
"modelName": "resnet-18_scripted", | ||
"serializedFile": "resnet-18.pt", | ||
"handler": "image_classifier", | ||
"modelVersion": "1.0" | ||
}, | ||
"archiverVersion": "0.3.0" | ||
} |
Large diffs are not rendered by default.
Oops, something went wrong.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
viam-sdk | ||
numpy | ||
torch==2.2.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
#!/usr/bin/env bash | ||
|
||
# bash safe mode. look at `set --help` to see what these are doing | ||
set -euxo pipefail | ||
|
||
cd $(dirname $0) | ||
source .env | ||
./setup.sh | ||
|
||
# Be sure to use `exec` so that termination signals reach the python process, | ||
# or handle forwarding termination signals manually | ||
exec $PYTHON -m src.main $@ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#!/usr/bin/env bash | ||
# setup.sh -- environment bootstrapper for python virtualenv | ||
|
||
set -euo pipefail | ||
|
||
SUDO=sudo | ||
if ! command -v $SUDO; then | ||
echo no sudo on this system, proceeding as current user | ||
SUDO="" | ||
fi | ||
|
||
if command -v apt-get; then | ||
if dpkg -l python3-venv; then | ||
echo "python3-venv is installed, skipping setup" | ||
else | ||
if ! apt info python3-venv; then | ||
echo python3-venv package info not found, trying apt update | ||
$SUDO apt-get -qq update | ||
fi | ||
$SUDO apt-get install -qqy python3-venv | ||
fi | ||
else | ||
echo Skipping tool installation because your platform is missing apt-get. | ||
echo If you see failures below, install the equivalent of python3-venv for your system. | ||
fi | ||
|
||
source .env | ||
echo creating virtualenv at $VIRTUAL_ENV | ||
python3 -m venv $VIRTUAL_ENV | ||
echo installing dependencies from requirements.txt | ||
$VIRTUAL_ENV/bin/pip install -r requirements.txt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import asyncio | ||
|
||
|
||
from viam.module.module import Module | ||
from viam.resource.registry import Registry, ResourceCreatorRegistration | ||
from .torch_mlmodel_module import TorchMLModelModule | ||
from viam.services.mlmodel import MLModel | ||
|
||
|
||
async def main(): | ||
""" | ||
This function creates and starts a new module, after adding all desired | ||
resource models. Resource creators must be registered to the resource | ||
registry before the module adds the resource model. | ||
""" | ||
Registry.register_resource_creator( | ||
MLModel.SUBTYPE, | ||
TorchMLModelModule.MODEL, | ||
ResourceCreatorRegistration( | ||
TorchMLModelModule.new_service, TorchMLModelModule.validate_config | ||
), | ||
) | ||
module = Module.from_args() | ||
|
||
module.add_model_from_registry(MLModel.SUBTYPE, TorchMLModelModule.MODEL) | ||
await module.start() | ||
|
||
|
||
if __name__ == "__main__": | ||
asyncio.run(main()) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import torch | ||
from typing import List, Iterable, Dict, Any | ||
from numpy.typing import NDArray | ||
import torch.nn as nn | ||
from collections import OrderedDict | ||
from viam.logging import getLogger | ||
|
||
LOGGER = getLogger(__name__) | ||
|
||
|
||
class TorchModel: | ||
def __init__( | ||
self, | ||
path_to_serialized_file: str, | ||
model: nn.Module = None, | ||
) -> None: | ||
if model is not None: | ||
self.model = model | ||
else: | ||
self.model = torch.load(path_to_serialized_file) | ||
if not isinstance(self.model, nn.Module): | ||
if isinstance(self.model, OrderedDict): | ||
LOGGER.error( | ||
f"the file {path_to_serialized_file} provided as model file is of type collections.OrderedDict, which suggests that the provided file describes weights instead of a standalone model" | ||
) | ||
raise TypeError( | ||
f"the model is of type {type(self.model)} instead of nn.Module type" | ||
) | ||
self.model.eval() | ||
|
||
def infer(self, input): | ||
input = self.prepare_input(input) | ||
with torch.no_grad(): | ||
output = self.model(*input) | ||
return self.wrap_output(output) | ||
|
||
@staticmethod | ||
def prepare_input(input_tensor: Dict[str, NDArray]) -> List[NDArray]: | ||
return [torch.from_numpy(tensor) for tensor in input_tensor.values()] | ||
|
||
@staticmethod | ||
def wrap_output(output: Any) -> Dict[str, NDArray]: | ||
if isinstance(output, Iterable): | ||
if len(output) == 1: | ||
output = output[0] # unpack batched results | ||
|
||
if isinstance(output, torch.Tensor): | ||
return {"output_0": output.numpy()} | ||
|
||
elif isinstance(output, dict): | ||
for tensor_name, tensor in output.items(): | ||
if isinstance(tensor, torch.Tensor): | ||
output[tensor_name] = tensor.numpy() | ||
|
||
return output | ||
elif isinstance(output, Iterable): | ||
res = {} | ||
count = 0 | ||
for out in output: | ||
res[f"output_{count}"] = out | ||
count += 1 | ||
return res | ||
|
||
else: | ||
raise TypeError(f"can't convert output of type {type(output)} to array") |
Empty file.
Oops, something went wrong.