Skip to content

Commit

Permalink
Added OrtValue class
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Sep 11, 2024
1 parent b0f606e commit ad0b54e
Show file tree
Hide file tree
Showing 3 changed files with 279 additions and 240 deletions.
243 changes: 3 additions & 240 deletions src/InferenceSession.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

class InferenceSession
{
use Utils;

private $ffi;
private $api;
private $session;
Expand Down Expand Up @@ -162,7 +164,7 @@ public function run($outputNames, $inputFeed, $logSeverityLevel = null, $logVerb

$output = [];
foreach ($outputTensor as $t) {
$output[] = $this->createFromOnnxValue($t);
$output[] = (new OrtValue($t))->toObject();
}

// TODO use finally
Expand Down Expand Up @@ -279,13 +281,6 @@ private function loadSession($path, $sessionOptions)
return $session;
}

private function loadAllocator()
{
$allocator = $this->ffi->new('OrtAllocator*');
$this->checkStatus(($this->api->GetAllocatorWithDefaultOptions)(\FFI::addr($allocator)));
return $allocator;
}

private function loadInputs()
{
$inputs = [];
Expand Down Expand Up @@ -453,189 +448,6 @@ private function cstring($str)
return $ptr;
}

private function createFromOnnxValue($outPtr)
{
try {
$outType = $this->ffi->new('ONNXType');
$this->checkStatus(($this->api->GetValueType)($outPtr, \FFI::addr($outType)));

if ($outType->cdata == $this->ffi->ONNX_TYPE_TENSOR) {
$typeinfo = $this->ffi->new('OrtTensorTypeAndShapeInfo*');
$this->checkStatus(($this->api->GetTensorTypeAndShape)($outPtr, \FFI::addr($typeinfo)));

[$type, $shape] = $this->tensorTypeAndShape($typeinfo);

// TODO skip if string
$tensorData = $this->ffi->new('void*');
$this->checkStatus(($this->api->GetTensorMutableData)($outPtr, \FFI::addr($tensorData)));

$outSize = $this->ffi->new('size_t');
$this->checkStatus(($this->api->GetTensorShapeElementCount)($typeinfo, \FFI::addr($outSize)));
$outputTensorSize = $outSize->cdata;

($this->api->ReleaseTensorTypeAndShapeInfo)($typeinfo);

$castTypes = $this->castTypes();
if (isset($castTypes[$type])) {
$arr = $this->ffi->cast($castTypes[$type] . "[$outputTensorSize]", $tensorData);
} elseif ($type == $this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) {
$arr = $this->createStringsFromOnnxValue($outPtr, $outputTensorSize);
} else {
$this->unsupportedType('element', $type);
}

$i = 0;
return $this->fillOutput($arr, $shape, $i);
} elseif ($outType->cdata == $this->ffi->ONNX_TYPE_SEQUENCE) {
$out = $this->ffi->new('size_t');
$this->checkStatus(($this->api->GetValueCount)($outPtr, \FFI::addr($out)));

$ret = [];
for ($i = 0; $i < $out->cdata; $i++) {
$seq = $this->ffi->new('OrtValue*');
$this->checkStatus(($this->api->GetValue)($outPtr, $i, $this->allocator, \FFI::addr($seq)));
$ret[] = $this->createFromOnnxValue($seq);
}
return $ret;
} elseif ($outType->cdata == $this->ffi->ONNX_TYPE_MAP) {
$typeShape = $this->ffi->new('OrtTensorTypeAndShapeInfo*');
$mapKeys = $this->ffi->new('OrtValue*');
$mapValues = $this->ffi->new('OrtValue*');
$elemType = $this->ffi->new('ONNXTensorElementDataType');

$this->checkStatus(($this->api->GetValue)($outPtr, 0, $this->allocator, \FFI::addr($mapKeys)));
$this->checkStatus(($this->api->GetValue)($outPtr, 1, $this->allocator, \FFI::addr($mapValues)));
$this->checkStatus(($this->api->GetTensorTypeAndShape)($mapKeys, \FFI::addr($typeShape)));
$this->checkStatus(($this->api->GetTensorElementType)($typeShape, \FFI::addr($elemType)));

($this->api->ReleaseTensorTypeAndShapeInfo)($typeShape);

// TODO support more types
if ($elemType->cdata == $this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
$ret = [];
$keys = $this->createFromOnnxValue($mapKeys);
$values = $this->createFromOnnxValue($mapValues);
return array_combine($keys, $values);
} else {
$this->unsupportedType('element', $elemType);
}
} else {
$this->unsupportedType('ONNX', $outType->cdata);
}
} finally {
if (!\FFI::isNull($outPtr)) {
($this->api->ReleaseValue)($outPtr);
}
}
}

private function fillOutput($ptr, $shape, &$i)
{
$dim = array_shift($shape);

if (count($shape) == 0) {
$row = [];
for ($j = 0; $j < $dim; $j++) {
$row[$j] = $ptr[$i];
$i++;
}
return $row;
} else {
$output = [];
for ($j = 0; $j < $dim; $j++) {
$output[] = $this->fillOutput($ptr, $shape, $i);
}
return $output;
}
}

private function createStringsFromOnnxValue($outPtr, $outputTensorSize)
{
$len = $this->ffi->new('size_t');
$this->checkStatus(($this->api->GetStringTensorDataLength)($outPtr, \FFI::addr($len)));

$sLen = $len->cdata;
$s = $this->ffi->new("char[$sLen]");
$offsets = $this->ffi->new("size_t[$outputTensorSize]");
$this->checkStatus(($this->api->GetStringTensorContent)($outPtr, $s, $sLen, $offsets, $outputTensorSize));

$result = [];
foreach ($offsets as $i => $v) {
$start = $v;
$end = $i < $outputTensorSize - 1 ? $offsets[$i + 1] : $sLen;
$size = $end - $start;
$result[] = \FFI::string($s + $start, $size);
}
return $result;
}

private static function checkStatus($status)
{
if (!is_null($status)) {
$message = (self::api()->GetErrorMessage)($status);
(self::api()->ReleaseStatus)($status);
throw new Exception($message);
}
}

private function nodeInfo($typeinfo)
{
$onnxType = $this->ffi->new('ONNXType');
$this->checkStatus(($this->api->GetOnnxTypeFromTypeInfo)($typeinfo, \FFI::addr($onnxType)));

if ($onnxType->cdata == $this->ffi->ONNX_TYPE_TENSOR) {
$tensorInfo = $this->ffi->new('OrtTensorTypeAndShapeInfo*');
// don't free tensor_info
$this->checkStatus(($this->api->CastTypeInfoToTensorInfo)($typeinfo, \FFI::addr($tensorInfo)));

[$type, $shape] = $this->tensorTypeAndShape($tensorInfo);
$elementDataType = $this->elementDataTypes()[$type];
return ['type' => "tensor($elementDataType)", 'shape' => $shape];
} elseif ($onnxType->cdata == $this->ffi->ONNX_TYPE_SEQUENCE) {
$sequenceTypeInfo = $this->ffi->new('OrtSequenceTypeInfo*');
$this->checkStatus(($this->api->CastTypeInfoToSequenceTypeInfo)($typeinfo, \FFI::addr($sequenceTypeInfo)));
$nestedTypeInfo = $this->ffi->new('OrtTypeInfo*');
$this->checkStatus(($this->api->GetSequenceElementType)($sequenceTypeInfo, \FFI::addr($nestedTypeInfo)));
$v = $this->nodeInfo($nestedTypeInfo)['type'];

return ['type' => "seq($v)", 'shape' => []];
} elseif ($onnxType->cdata == $this->ffi->ONNX_TYPE_MAP) {
$mapTypeInfo = $this->ffi->new('OrtMapTypeInfo*');
$this->checkStatus(($this->api->CastTypeInfoToMapTypeInfo)($typeinfo, \FFI::addr($mapTypeInfo)));

// key
$keyType = $this->ffi->new('ONNXTensorElementDataType');
$this->checkStatus(($this->api->GetMapKeyType)($mapTypeInfo, \FFI::addr($keyType)));
$k = $this->elementDataTypes()[$keyType->cdata];

// value
$valueTypeInfo = $this->ffi->new('OrtTypeInfo*');
$this->checkStatus(($this->api->GetMapValueType)($mapTypeInfo, \FFI::addr($valueTypeInfo)));
$v = $this->nodeInfo($valueTypeInfo)['type'];

return ['type' => "map($k,$v)", 'shape' => []];
} else {
$this->unsupportedType('ONNX', $onnxType->cdata);
}
}

private function castTypes()
{
return [
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => 'float',
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 => 'uint8_t',
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 => 'int8_t',
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 => 'uint16_t',
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 => 'int16_t',
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 => 'int32_t',
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 => 'int64_t',
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL => 'bool',
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => 'double',
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 => 'uint32_t',
$this->ffi->ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => 'uint64_t'
];
}

private function elementDataTypes()
{
return [
Expand All @@ -659,60 +471,11 @@ private function elementDataTypes()
];
}

private function tensorTypeAndShape($tensorInfo)
{
$type = $this->ffi->new('ONNXTensorElementDataType');
$this->checkStatus(($this->api->GetTensorElementType)($tensorInfo, \FFI::addr($type)));

$numDimsPtr = $this->ffi->new('size_t');
$this->checkStatus(($this->api->GetDimensionsCount)($tensorInfo, \FFI::addr($numDimsPtr)));
$numDims = $numDimsPtr->cdata;

if ($numDims > 0) {
$nodeDims = $this->ffi->new("int64_t[$numDims]");
$this->checkStatus(($this->api->GetDimensions)($tensorInfo, $nodeDims, $numDims));
$dims = $this->readArray($nodeDims);

$symbolicDims = $this->ffi->new("char*[$numDims]");
$this->checkStatus(($this->api->GetSymbolicDimensions)($tensorInfo, $symbolicDims, $numDims));
for ($i = 0; $i < $numDims; $i++) {
$namedDim = \FFI::string($symbolicDims[$i]);
if ($namedDim != '') {
$dims[$i] = $namedDim;
}
}
} else {
$dims = [];
}

return [$type->cdata, $dims];
}

private function unsupportedType($name, $type)
{
throw new Exception("Unsupported $name type: $type");
}

private function readArray($cdata)
{
$arr = [];
$n = count($cdata);
for ($i = 0; $i < $n; $i++) {
$arr[] = $cdata[$i];
}
return $arr;
}

private function allocatorFree($ptr)
{
($this->api->AllocatorFree)($this->allocator, $ptr);
}

private static function api()
{
return FFI::api();
}

// wide string on Windows
// char string on Linux
// see ORTCHAR_T in onnxruntime_c_api.h
Expand Down
Loading

0 comments on commit ad0b54e

Please sign in to comment.