
Inference fails after loading the engine file #3

Open
oscarriddle opened this issue May 24, 2019 · 4 comments

Comments

@oscarriddle

Hi,

Here is my environment setting:

CentOS 7.0
PyTorch 1.1.0
TensorRT 5.1.2.2 with CUDA9.0
CUDA9.0
cuDNN7.5.0 with CUDA9.0
Python3.6.8

After installing the dependencies of this repo, I ran test.py and serialized the engine into the file "test.engine".

Then I tried to load it and run inference with normal TensorRT Python code, as below:

import numpy as np
import pycuda.autoinit
import pycuda.driver as cuda
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

def get_engine(engine_file):
    with open(engine_file, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
        print("Engine Loaded")
        return runtime.deserialize_cuda_engine(f.read())

if __name__ == '__main__':
    img = np.random.rand(1, 3,299,299)
    img /= 255.0
    img -= 0.5
    img *= 2.0
    bindings = []
    img = np.ascontiguousarray(img)
    engine = get_engine('test.engine')
    stream = cuda.Stream()
    context = engine.create_execution_context()
    output = np.empty(1000, dtype = np.float32)
    d_input = cuda.mem_alloc(1 * img.nbytes)
    d_output = cuda.mem_alloc(1 * output.nbytes)
    bindings = [int(d_input), int(d_output)]
    cuda.memcpy_htod_async(d_input, img, stream)
    context.execute_async(1, bindings, stream.handle, None)
    cuda.memcpy_dtoh_async(output, d_output, stream)
    stream.synchronize()

Then I get the error below:

$ python3 test_run.py
Engine Loaded
[TensorRT] WARNING: TensorRT was compiled against cuDNN 7.5.0 but is linked against cuDNN 7.5.1. This mismatch may potentially cause undefined behavior.
[TensorRT] WARNING: TensorRT was compiled against cuDNN 7.5.0 but is linked against cuDNN 7.5.1. This mismatch may potentially cause undefined behavior.
[TensorRT] ERROR: Parameter check failed at: engine.cpp::enqueue::451, condition: bindings[x] != nullptr

Any advice would be welcome.
Thanks,

@traveller59
Owner

traveller59 commented May 26, 2019

The example in the README generates an engine with two outputs, so you need to allocate memory for the second output and add it to bindings.
I recommend writing a simple high-level API based on allocate_buffers in common.py from the TensorRT examples to avoid this kind of bug. You can get all the shapes, dtypes, and names of the inputs and outputs from an engine instance and use them to write the API.
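A minimal sketch of such a helper, following the allocate_buffers pattern from the TensorRT Python samples (the HostDeviceMem class and do_inference wrapper below are illustrative, assuming the implicit-batch execute_async API used elsewhere in this thread):

import pycuda.autoinit
import pycuda.driver as cuda
import tensorrt as trt

class HostDeviceMem:
    """Pairs a pagelocked host buffer with its device allocation."""
    def __init__(self, host, device):
        self.host = host
        self.device = device

def allocate_buffers(engine):
    """Allocate a host/device buffer pair for every binding, sized from the engine."""
    inputs, outputs, bindings = [], [], []
    stream = cuda.Stream()
    for name in engine:  # iterating an engine yields its binding names
        shape = engine.get_binding_shape(name)
        dtype = trt.nptype(engine.get_binding_dtype(name))
        size = trt.volume(shape) * engine.max_batch_size
        host_mem = cuda.pagelocked_empty(size, dtype)
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        bindings.append(int(device_mem))
        if engine.binding_is_input(name):
            inputs.append(HostDeviceMem(host_mem, device_mem))
        else:
            outputs.append(HostDeviceMem(host_mem, device_mem))
    return inputs, outputs, bindings, stream

def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):
    """Copy inputs to the device, run the engine, and copy every output back."""
    for inp in inputs:
        cuda.memcpy_htod_async(inp.device, inp.host, stream)
    context.execute_async(batch_size, bindings, stream.handle, None)
    for out in outputs:
        cuda.memcpy_dtoh_async(out.host, out.device, stream)
    stream.synchronize()
    return [out.host for out in outputs]

Because every binding gets a buffer automatically, a forgotten output should no longer trigger the bindings[x] != nullptr check.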

@oscarriddle
Author

oscarriddle commented May 27, 2019

Hi @traveller59
Thanks for your reply.
After updating the test script to handle the two outputs, the error disappeared.

import torch
import torchvision
import tensorrt as trt
import torch2trt
import time
import numpy as np
import pycuda.driver as cuda
import pycuda.autoinit

# TRT_LOGGER and get_engine() are defined as in the first snippet above.

if __name__ == '__main__':
    img = np.random.rand(1, 3, 299, 299)
    #img /= 255.0
    #img -= 0.5
    #img *= 2.0
    bindings = []
    img = np.ascontiguousarray(img)
    #out = infer(get_engine('test.engine'), input, 1)
    engine = get_engine('test.engine')
    #runtime = trt.infer.create_infer_runtime(G_LOGGER)
    stream = cuda.Stream()
    context = engine.create_execution_context()

    output = np.empty(1000, dtype=np.float32)
    output2 = np.empty(1000, dtype=np.float32)
    d_input = cuda.mem_alloc(1 * img.nbytes)
    d_output = cuda.mem_alloc(1 * output.nbytes)
    d_output2 = cuda.mem_alloc(1 * output2.nbytes)
    bindings = [int(d_input), int(d_output), int(d_output2)]

    for i in range(100):
        a1 = time.time()
        cuda.memcpy_htod_async(d_input, img, stream)
        context.execute_async(1, bindings, stream.handle, None)
        cuda.memcpy_dtoh_async(output, d_output, stream)
        cuda.memcpy_dtoh_async(output2, d_output2, stream)
        stream.synchronize()
        a2 = time.time()
        print('Batch {}-th, Time {:.3f}ms'.format(i, (a2 - a1) * 1000))
    print(output)
    print(output2)

I inspected the data in the arrays output and output2; output2 is all zeros, and output is as below:

[1. 1. 1. 0. 0. 1. 1. 1. 1. 1. 0. 1. 0. 1. 1. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 1. 0. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 1. 1. 1. 1. 0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 1. 0. 1. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 1. 0. 0. 1. 0. 1. 1. 0. 1. 0. 0. 1. 1. 0. 0. 0. 0. 0. 1. 0.1. 0. 0. 1. 1. 0. 0. 1. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.0. 1. 0. 0. 0. 1. 1. 0. 0. 1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 1. 1. 1. 0. 0. 1. 1. 0. 0. 0. 0. 1. 1. 0. 1. 1. 1. 0. 0. 0. 0. 1. 1. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 0. 1. 0. 1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 0. 1. 1. 0. 1. 1. 0. 0. 0. 1. 0. 1. 0. 0. 1. 0. 0. 1. 1. 0. 0. 0. 1. 0. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 0. 1. 1. 1. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1. 0. 0. 1. 1. 0. 1. 0. 1. 1. 0. 0. 1. 0. 0. 1. 0. 0. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 1. 1. 1. 0. 1. 1. 0. 0. 0. 0. 0. 0. 1. 1. 1. 0. 0. 0. 1. 1. 0. 1. 0. 0. 1. 1. 1. 0. 0. 1. 0. 1. 0. 1. 0. 0. 0. 1. 0. 1. 0. 1. 1. 0. 0. 1. 1. 0. 1. 1. 0. 1. 0. 1. 0. 1. 1. 0. 0. 1. 0. 0. 1. 1. 0. 1. 1. 1. 0. 0. 0. 0. 1. 1. 1. 1. 0. 1. 0. 1. 1. 0. 1. 0. 0. 0. 1. 0. 0. 1. 0. 1. 0. 1. 1. 1. 0. 1. 1. 0. 1. 1. 1. 1. 1. 0. 1. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1. 1. 0. 0. 0. 0. 1. 1. 1. 1. 1. 0. 1. 0. 0. 1. 0. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 1. 1. 1. 0. 0. 0. 1. 0. 1. 0. 1. 0. 0. 0. 0. 1. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 1. 1. 0. 1. 1. 1. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0. 1. 1. 1. 0. 1. 1. 0. 1. 0. 0. 1. 0. 0. 1. 0. 1. 1. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 1. 0. 1. 1. 1. 0. 1. 1. 1. 0. 0. 1. 1. 0. 0. 1. 0. 0. 0. 1. 0. 0. 1. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 1. 0. 1. 1. 1. 0. 1. 0. 0. 1. 0. 0. 0. 0. 1. 0. 1. 1. 1. 0. 1. 0. 0. 1. 1. 0. 0. 0. 1. 1. 0. 1. 1. 1. 0. 0. 0. 0. 1. 0. 0. 1. 1. 1. 0. 1. 1. 1. 1. 1. 0. 0. 1. 1. 0. 0. 1. 0. 0. 0. 1. 1. 0. 0. 1. 1. 1. 1. 1. 0. 1. 0. 1. 1. 1. 1. 0. 0. 0. 1. 0. 0. 1. 1. 1. 0. 0. 0. 0. 1. 0. 1. 1. 0. 0. 0. 1. 1. 1. 0. 0. 0. 0. 0. 1. 1. 0. 0. 1. 0. 1. 0. 0. 1. 0. 1. 1. 1. 1. 0. 1. 0. 0. 0. 1. 0. 1. 0. 1. 1. 0. 0. 0. 1. 0. 0. 0. 1. 0. 1. 1. 0. 1. 0. 1. 1. 1. 1. 1. 1. 0. 1. 1. 0. 1. 0. 0. 1. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 1. 1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 1. 0. 0. 1. 0. 1. 0. 0. 0. 0. 1. 0. 1. 1. 0. 1. 0. 0. 1. 1. 0. 1. 1. 1. 0. 0. 0. 0. 1. 0. 1. 1. 1. 0. 0. 1. 1. 0. 1. 1. 0. 1. 0. 0. 1. 0. 1. 1. 1. 1. 0. 1. 0. 0. 1. 1. 0. 0. 0. 0. 1. 1. 0. 1. 0. 1. 0. 1. 0. 0. 1. 1. 0. 0. 0. 1. 1. 1. 1. 1. 0. 0. 1. 1. 1. 1. 1. 1. 0. 1. 0. 0. 1. 0. 0. 1. 1. 0. 0. 0. 0. 1. 1. 0. 1. 0. 1. 1. 0. 0. 1. 0. 0. 1. 1. 1. 1. 1. 1. 0. 0. 0. 1. 1. 0. 0. 0. 1. 0. 1. 0. 1. 1. 1. 1. 0. 0. 1. 1. 0. 1. 0. 0. 0. 1. 1. 0. 0. 1. 0. 1. 0. 1. 1. 1. 0. 1. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 1. 1. 0. 0. 1. 0. 0. 0. 1. 1. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 1. 0. 1. 0. 1. 1. 1. 0. 0. 0. 1. 1. 0. 1. 1. 1. 1. 1. 0. 0. 1. 0. 1. 1. 1. 1. 0. 0. 0. 1. 0. 0. 1. 0. 1. 0. 0. 0. 0. 0. 0. 1. 1. 0. 1. 1. 0. 0. 0. 1. 1. 1. 0. 1. 0. 1. 1. 1. 1. 0. 1. 0. 0.]

Meanwhile, I generated an input of the same shape with torch.rand().cuda(), fed it into the PyTorch model, and the output tensor is as below:

tensor([[-1.1718e+00,  4.6266e-01,  1.0212e+00,  2.2632e-01,  4.7583e-01,
         -6.3883e-01,  1.0015e-01,  1.5616e+00,  1.0254e+00,  1.1679e+00,
          1.8992e+00,  2.9743e+00,  3.4913e+00,  2.1630e+00,  2.3985e+00,
          2.6715e+00,  3.0376e+00,  1.4133e+00,  3.0594e+00,  2.1019e+00,
          1.1495e+00,  5.6692e+00,  3.9588e+00,  3.3769e+00,  2.3353e+00,
         -1.0645e+00, -7.9425e-01, -8.1189e-01, -1.1870e+00, -5.1096e-01,
         -8.8577e-02,  2.2844e-01, -1.5711e+00, -2.0256e-01, -3.3481e-01,
...

It seems my conversion configuration is incorrect; could you give some advice?
Thanks

Below is my conversion script:

import torch
import torchvision
import tensorrt as trt
import torch2trt
import time
import numpy

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

net = torchvision.models.inception_v3(pretrained=True).eval()
inputs = torch.rand(1, 3, 299, 299)
graph_pth = torch2trt.GraphModule(net, inputs, param_exclude=".*AuxLogits.*")
torch_mode_out = graph_pth(inputs)

def toy_example(x):
    return torch.softmax(x, 1), torch.sigmoid(x)

graph_pth_toy = torch2trt.GraphModule(toy_example, torch_mode_out)
probs, sigmoid = graph_pth_toy(torch_mode_out, verbose=True)

with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as trt_net:
    builder.max_workspace_size = 1 << 30
    with torch2trt.trt_network(trt_net):  # must use this to enter trt mode
        img = trt_net.add_input(name="image", shape=[3, 299, 299], dtype=trt.float32)
        trt_mode_out = graph_pth(img, verbose=True)  # call graph_pth like torch module call
        trt_mode_out, sigmoid = graph_pth_toy(trt_mode_out)
    trt_mode_out.name = "output_softmax"
    sigmoid.name = "output_sigmoid"
    trt_net.mark_output(tensor=trt_mode_out)
    trt_net.mark_output(tensor=sigmoid)
    engine = builder.build_cuda_engine(trt_net)
    with open("test.engine", "wb") as f:
        f.write(engine.serialize())

@traveller59
Owner

You need to convert the image to np.float32.
Please don't use the raw API... you can try my high-level API in the newest code.
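For reference, the dtype fix is a one-line change to the test script above (a minimal sketch: np.random.rand returns float64, while the engine's "image" input was declared as trt.float32 in the conversion script):

import numpy as np

# np.random.rand returns float64, but the engine's "image" input is float32,
# so cast the host array before the host-to-device copy.
img = np.random.rand(1, 3, 299, 299).astype(np.float32)
img = np.ascontiguousarray(img)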

@oscarriddle
Author

oscarriddle commented May 28, 2019

Hi, I tried your newest inference code and got the output results.
I noticed you compare the different results by norm to check consistency.
I also tried feeding exactly the same input to the original PyTorch model, like below:

import torch
import torchvision
import tensorrt as trt
import torch2trt
import time
import numpy as np

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
net = torchvision.models.inception_v3(pretrained=True).eval()                                                                                 
img = np.load('input_raw.bin.npy')
inputs = torch.from_numpy(img)                                                                                                    
model = net.cuda()
inp = inputs.cuda()
for i in range(1):
    a1 = time.time()
    out = model(inp)
    a2 = time.time()
    print('{}, {}, {}'.format(i, out, a2-a1))

But I got a different result compared to the TensorRT path. Would you leave some comments on how to address this issue?
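For context, the norm comparison mentioned above amounts to something along these lines (a hypothetical helper, not part of the repo; it assumes the TensorRT output array and the PyTorch tensor cover the same flattened output):

import numpy as np

def relative_error(trt_out, torch_out):
    """Compare a TensorRT output array against the matching PyTorch tensor by norm."""
    a = np.asarray(trt_out, dtype=np.float32).ravel()
    b = torch_out.detach().cpu().numpy().astype(np.float32).ravel()
    return np.linalg.norm(a - b) / (np.linalg.norm(b) + 1e-12)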
