Skip to content

Commit 36e64ca

Browse files
committed
Support structured output
1 parent 5b5cdc4 commit 36e64ca

11 files changed

+134
-19
lines changed

NEWS.md

+5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# ollamar (development version)
22

3+
- `generate()` and `chat()` support [structured output](https://ollama.com/blog/structured-outputs) via the `format` parameter.
4+
- `test_connection()` returns a boolean instead of an `httr2` object. #29
5+
- `chat()` supports [tool calling](https://ollama.com/blog/tool-support) via the `tools` parameter. Added the `get_tool_calls()` helper function to process tools. #30
6+
- Simplify README and add Get started vignette with more examples.
7+
38
# ollamar 1.2.1
49

510
- `generate()` and `chat()` accept multiple images as prompts/messages.

R/ollama.R

+17-7
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ create_request <- function(endpoint, host = NULL) {
5858
#' @param prompt A character string of the prompt like "The sky is..."
5959
#' @param suffix A character string after the model response. Default is "".
6060
#' @param images A path to an image file to include in the prompt. Default is "".
61+
#' @param format The format to return the response in. Accepts a JSON schema supplied as an R list to enable structured output. Default is an empty list.
6162
#' @param system A character string of the system prompt (overrides what is defined in the Modelfile). Default is "".
6263
#' @param template A character string of the prompt template (overrides what is defined in the Modelfile). Default is "".
6364
#' @param context A list of context from a previous response to include previous conversation in the prompt. Default is an empty list.
@@ -86,10 +87,10 @@ create_request <- function(endpoint, host = NULL) {
8687
#' image_path <- file.path(system.file("extdata", package = "ollamar"), "image1.png")
8788
#' # use vision or multimodal model such as https://ollama.com/benzie/llava-phi-3
8889
#' generate("benzie/llava-phi-3:latest", "What is in the image?", images = image_path, output = "text")
89-
generate <- function(model, prompt, suffix = "", images = "", system = "", template = "", context = list(), stream = FALSE, raw = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text", "req"), endpoint = "/api/generate", host = NULL, ...) {
90+
generate <- function(model, prompt, suffix = "", images = "", format = list(), system = "", template = "", context = list(), stream = FALSE, raw = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text", "req", "structured"), endpoint = "/api/generate", host = NULL, ...) {
9091
output <- output[1]
91-
if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req")) {
92-
stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text', 'req'")
92+
if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req", "structured")) {
93+
stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text', 'req', 'structured'")
9394
}
9495

9596
req <- create_request(endpoint, host)
@@ -112,6 +113,10 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
112113
keep_alive = keep_alive
113114
)
114115

116+
if (length(format) != 0 & inherits(format, "list")) {
117+
body_json$format <- format
118+
}
119+
115120
# check if model options are passed and specified correctly
116121
opts <- list(...)
117122
if (length(opts) > 0) {
@@ -169,8 +174,9 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
169174
#' @param messages A list with list of messages for the model (see examples below).
170175
#' @param tools Tools for the model to use if supported. Requires stream = FALSE. Default is an empty list.
171176
#' @param stream Enable response streaming. Default is FALSE.
177+
#' @param format The format to return the response in. Accepts a JSON schema supplied as an R list to enable structured output. Default is an empty list.
172178
#' @param keep_alive The duration to keep the connection alive. Default is "5m".
173-
#' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text", "req" (httr2_request object), "tools" (tool calling)
179+
#' @param output The output format. Default is "resp". Other options are "jsonlist", "raw", "df", "text", "req" (httr2_request object), "tools" (tool calling), "structured" (structured output)
174180
#' @param endpoint The endpoint to chat with the model. Default is "/api/chat".
175181
#' @param host The base URL to use. Default is NULL, which uses Ollama's default base URL.
176182
#' @param ... Additional options to pass to the model.
@@ -208,10 +214,10 @@ generate <- function(model, prompt, suffix = "", images = "", system = "", templ
208214
#' list(role = "user", content = "What is in the image?", images = image_path)
209215
#' )
210216
#' chat("benzie/llava-phi-3", messages, output = 'text')
211-
chat <- function(model, messages, tools = list(), stream = FALSE, keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text", "req", "tools"), endpoint = "/api/chat", host = NULL, ...) {
217+
chat <- function(model, messages, tools = list(), stream = FALSE, format = list(), keep_alive = "5m", output = c("resp", "jsonlist", "raw", "df", "text", "req", "tools", "structured"), endpoint = "/api/chat", host = NULL, ...) {
212218
output <- output[1]
213-
if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req", "tools")) {
214-
stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text'")
219+
if (!output %in% c("df", "resp", "jsonlist", "raw", "text", "req", "tools", "structured")) {
220+
stop("Invalid output format specified. Supported formats: 'df', 'resp', 'jsonlist', 'raw', 'text', 'tools', 'structured'")
215221
}
216222

217223
req <- create_request(endpoint, host)
@@ -231,6 +237,10 @@ chat <- function(model, messages, tools = list(), stream = FALSE, keep_alive = "
231237
keep_alive = keep_alive
232238
)
233239

240+
if (length(format) != 0 & inherits(format, "list")) {
241+
body_json$format <- format
242+
}
243+
234244
opts <- list(...)
235245
if (length(opts) > 0) {
236246
if (validate_options(...)) {

R/utils.R

+5-2
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ get_tool_calls <- function(resp) {
112112
#' Process httr2 response object
113113
#'
114114
#' @param resp A httr2 response object.
115-
#' @param output The output format. Default is "df". Other options are "jsonlist", "raw", "resp" (httr2 response object), "text", "tools" (tool_calls)
115+
#' @param output The output format. Default is "df". Other options are "jsonlist", "raw", "resp" (httr2 response object), "text", "tools" (tool_calls), "structured" (structured output).
116116
#'
117117
#' @return A data frame, json list, raw or httr2 response object.
118118
#' @export
@@ -122,7 +122,6 @@ get_tool_calls <- function(resp) {
122122
#' resp_process(resp, "df") # parse response to dataframe/tibble
123123
#' resp_process(resp, "jsonlist") # parse response to list
124124
#' resp_process(resp, "raw") # parse response to raw string
125-
#' resp_process(resp, "resp") # return input response object
126125
#' resp_process(resp, "text") # return text/character vector
127126
#' resp_process(resp, "tools") # return tool_calls
128127
resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text", "tools")) {
@@ -195,6 +194,8 @@ resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text
195194
return(df_response)
196195
} else if (output == "text") {
197196
return(df_response$response)
197+
} else if (output == "structured") {
198+
return(jsonlite::fromJSON(df_response$response))
198199
}
199200
} else if (grepl("api/chat", resp$url)) { # process chat endpoint
200201
json_body <- httr2::resp_body_json(resp)
@@ -209,6 +210,8 @@ resp_process <- function(resp, output = c("df", "jsonlist", "raw", "resp", "text
209210
return(df_response)
210211
} else if (output == "text") {
211212
return(df_response$content)
213+
} else if (output == "structured") {
214+
return(jsonlite::fromJSON(df_response$content))
212215
}
213216
} else if (grepl("api/tags", resp$url)) { # process tags endpoint
214217
json_body <- httr2::resp_body_json(resp)[[1]]

README.Rmd

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ The library also makes it easy to work with data structures (e.g., conversationa
2727

2828
To use this R library, ensure the [Ollama](https://ollama.com) app is installed. Ollama can use GPUs for accelerating LLM inference. See [Ollama GPU documentation](https://github.com/ollama/ollama/blob/main/docs/gpu.md) for more information.
2929

30-
See [Ollama's Github page](https://github.com/ollama/ollama) for more information. This library uses the [Ollama REST API (see documentation for details)](https://github.com/ollama/ollama/blob/main/docs/api.md) and has been tested on Ollama v0.1.30 and above. It was last tested on Ollama v0.3.10.
30+
See [Ollama's Github page](https://github.com/ollama/ollama) for more information. This library uses the [Ollama REST API (see documentation for details)](https://github.com/ollama/ollama/blob/main/docs/api.md) and was last tested on v0.5.4.
3131

3232
> Note: You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
3333

README.md

+1-2
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ for more information.
3131
See [Ollama’s Github page](https://github.com/ollama/ollama) for more
3232
information. This library uses the [Ollama REST API (see documentation
3333
for details)](https://github.com/ollama/ollama/blob/main/docs/api.md)
34-
and has been tested on Ollama v0.1.30 and above. It was last tested on
35-
Ollama v0.3.10.
34+
and was last tested on v0.5.4.
3635

3736
> Note: You should have at least 8 GB of RAM available to run the 7B
3837
> models, 16 GB to run the 13B models, and 32 GB to run the 33B models.

man/chat.Rd

+5-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/generate.Rd

+4-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/resp_process.Rd

+1-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-chat.R

+37-2
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ test_that("chat function handles images in messages", {
130130
})
131131

132132

133-
test_that("chat function handles tools", {
133+
test_that("chat function tool calling", {
134134
skip_if_not(test_connection(), "Ollama server not available")
135135

136136
add_two_numbers <- function(x, y) {
@@ -188,7 +188,7 @@ test_that("chat function handles tools", {
188188

189189

190190
# test multiple tools
191-
msg <- create_message("what is three plus one? then multiply the output of that by ten")
191+
msg <- create_message("add three plus four. then multiply by ten")
192192
tools <- list(list(type = "function",
193193
"function" = list(
194194
name = "add_two_numbers",
@@ -234,3 +234,38 @@ test_that("chat function handles tools", {
234234
# expect_equal(resp[[2]]$name, "multiply_two_numbers")
235235

236236
})
237+
238+
239+
240+
241+
test_that("structured output", {
242+
skip_if_not(test_connection(), "Ollama server not available")
243+
244+
format <- list(
245+
type = "object",
246+
properties = list(
247+
name = list(
248+
type = "string"
249+
),
250+
capital = list(
251+
type = "string"
252+
),
253+
languages = list(
254+
type = "array",
255+
items = list(
256+
type = "string"
257+
)
258+
)
259+
),
260+
required = list("name", "capital", "languages")
261+
)
262+
263+
msg <- create_message("tell me about canada")
264+
resp <- chat("llama3.1", msg, format = format)
265+
# content <- httr2::resp_body_json(resp)$message$content
266+
structured_output <- resp_process(resp, "structured")
267+
expect_equal(tolower(structured_output$name), "canada")
268+
269+
})
270+
271+

tests/testthat/test-generate.R

+34
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,37 @@ test_that("generate function works with images", {
7979
})
8080

8181

82+
83+
84+
85+
test_that("structured output", {
86+
skip_if_not(test_connection(), "Ollama server not available")
87+
88+
format <- list(
89+
type = "object",
90+
properties = list(
91+
name = list(
92+
type = "string"
93+
),
94+
capital = list(
95+
type = "string"
96+
),
97+
languages = list(
98+
type = "array",
99+
items = list(
100+
type = "string"
101+
)
102+
)
103+
),
104+
required = list("name", "capital", "languages")
105+
)
106+
107+
msg <- "tell me about canada"
108+
resp <- generate("llama3.1", prompt = msg, format = format)
109+
# response <- httr2::resp_body_json(resp)$response
110+
structured_output <- resp_process(resp, "structured")
111+
expect_equal(tolower(structured_output$name), "canada")
112+
113+
})
114+
115+

vignettes/ollamar.Rmd

+24
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,30 @@ do.call(resp[[1]]$name, resp[[1]]$arguments) # 7
362362
do.call(resp[[2]]$name, resp[[2]]$arguments) # 70
363363
```
364364

365+
### Structured outputs
366+
367+
The `chat()` and `generate()` functions support [structured outputs](https://ollama.com/blog/structured-outputs), making it possible to constrain a model's output to a specified format defined by a JSON schema (R list).
368+
369+
```{r eval=FALSE}
370+
# define a JSON schema as a list to constrain a model's output
371+
format <- list(
372+
type = "object",
373+
properties = list(
374+
name = list(type = "string"),
375+
capital = list(type = "string"),
376+
languages = list(type = "array",
377+
items = list(type = "string")
378+
)
379+
),
380+
required = list("name", "capital", "languages")
381+
)
382+
383+
generate("llama3.1", "tell me about Canada", output = "structured", format = format)
384+
385+
msg <- create_message("tell me about Canada")
386+
chat("llama3.1", msg, format = format, output = "structured")
387+
```
388+
365389
### Parallel requests
366390

367391
For the `generate()` and `chat()` endpoints/functions, you can specify `output = 'req'` in the function so the functions return `httr2_request` objects instead of `httr2_response` objects.

0 commit comments

Comments
 (0)