feat(bark-cpp): add new bark.cpp backend #4287

Merged · 5 commits · Nov 28, 2024
Changes from all commits

16 changes: 2 additions & 14 deletions .github/workflows/bump_deps.yaml
@@ -12,24 +12,12 @@ jobs:
- repository: "ggerganov/llama.cpp"
variable: "CPPLLAMA_VERSION"
branch: "master"
- repository: "go-skynet/go-ggml-transformers.cpp"
variable: "GOGGMLTRANSFORMERS_VERSION"
branch: "master"
- repository: "donomii/go-rwkv.cpp"
variable: "RWKV_VERSION"
branch: "main"
- repository: "ggerganov/whisper.cpp"
variable: "WHISPER_CPP_VERSION"
branch: "master"
- repository: "go-skynet/go-bert.cpp"
variable: "BERT_VERSION"
branch: "master"
- repository: "go-skynet/bloomz.cpp"
variable: "BLOOMZ_VERSION"
- repository: "PABannier/bark.cpp"
variable: "BARKCPP_VERSION"
branch: "main"
- repository: "mudler/go-ggllm.cpp"
variable: "GOGGLLM_VERSION"
branch: "master"
- repository: "mudler/go-stable-diffusion"
variable: "STABLEDIFFUSION_VERSION"
branch: "master"
1 change: 1 addition & 0 deletions .gitignore
@@ -2,6 +2,7 @@
 /sources/
 __pycache__/
 *.a
+*.o
 get-sources
 prepare-sources
 /backend/cpp/llama/grpc-server
37 changes: 36 additions & 1 deletion Makefile
@@ -26,6 +26,10 @@ STABLEDIFFUSION_VERSION?=4a3cd6aeae6f66ee57eae9a0075f8c58c3a6a38f
 TINYDREAM_REPO?=https://github.com/M0Rf30/go-tiny-dream
 TINYDREAM_VERSION?=c04fa463ace9d9a6464313aa5f9cd0f953b6c057

+# bark.cpp
+BARKCPP_REPO?=https://github.com/PABannier/bark.cpp.git
+BARKCPP_VERSION?=v1.0.0
+
 ONNX_VERSION?=1.20.0
 ONNX_ARCH?=x64
 ONNX_OS?=linux
@@ -201,6 +205,13 @@ ALL_GRPC_BACKENDS+=backend-assets/grpc/llama-ggml
 ALL_GRPC_BACKENDS+=backend-assets/grpc/llama-cpp-grpc
 ALL_GRPC_BACKENDS+=backend-assets/util/llama-cpp-rpc-server
 ALL_GRPC_BACKENDS+=backend-assets/grpc/whisper
+
+# bark-cpp is only built on linux/x64; the ONNX_OS/ONNX_ARCH variables double as the host platform check here
+ifeq ($(ONNX_OS),linux)
+ifeq ($(ONNX_ARCH),x64)
+ALL_GRPC_BACKENDS+=backend-assets/grpc/bark-cpp
+endif
+endif
+
 ALL_GRPC_BACKENDS+=backend-assets/grpc/local-store
 ALL_GRPC_BACKENDS+=backend-assets/grpc/silero-vad
 ALL_GRPC_BACKENDS+=$(OPTIONAL_GRPC)
@@ -233,6 +244,22 @@ sources/go-llama.cpp:
 	git checkout $(GOLLAMA_VERSION) && \
 	git submodule update --init --recursive --depth 1 --single-branch

+sources/bark.cpp:
+	git clone --recursive https://github.com/PABannier/bark.cpp.git sources/bark.cpp && \
+	cd sources/bark.cpp && \
+	git checkout $(BARKCPP_VERSION) && \
+	git submodule update --init --recursive --depth 1 --single-branch
+
+sources/bark.cpp/build/libbark.a: sources/bark.cpp
+	cd sources/bark.cpp && \
+	mkdir build && \
+	cd build && \
+	cmake $(CMAKE_ARGS) .. && \
+	cmake --build . --config Release
+
+# stage the cgo wrapper archive once the upstream bark.cpp static library is built
+backend/go/bark/libbark.a: sources/bark.cpp/build/libbark.a
+	$(MAKE) -C backend/go/bark libbark.a
+
 sources/go-llama.cpp/libbinding.a: sources/go-llama.cpp
 	$(MAKE) -C sources/go-llama.cpp BUILD_TYPE=$(STABLE_BUILD_TYPE) libbinding.a

@@ -302,7 +329,7 @@ sources/whisper.cpp:
 sources/whisper.cpp/libwhisper.a: sources/whisper.cpp
 	cd sources/whisper.cpp && $(MAKE) libwhisper.a libggml.a

-get-sources: sources/go-llama.cpp sources/go-piper sources/whisper.cpp sources/go-stable-diffusion sources/go-tiny-dream backend/cpp/llama/llama.cpp
+get-sources: sources/go-llama.cpp sources/go-piper sources/bark.cpp sources/whisper.cpp sources/go-stable-diffusion sources/go-tiny-dream backend/cpp/llama/llama.cpp

 replace:
 	$(GOCMD) mod edit -replace github.com/ggerganov/whisper.cpp=$(CURDIR)/sources/whisper.cpp
@@ -343,6 +370,7 @@ clean: ## Remove build related file
 	rm -rf release/
 	rm -rf backend-assets/*
 	$(MAKE) -C backend/cpp/grpc clean
+	$(MAKE) -C backend/go/bark clean
 	$(MAKE) -C backend/cpp/llama clean
 	rm -rf backend/cpp/llama-* || true
 	$(MAKE) dropreplace
@@ -792,6 +820,13 @@ ifneq ($(UPX),)
 	$(UPX) backend-assets/grpc/llama-ggml
 endif

+backend-assets/grpc/bark-cpp: backend/go/bark/libbark.a backend-assets/grpc
+	CGO_LDFLAGS="$(CGO_LDFLAGS)" C_INCLUDE_PATH=$(CURDIR)/backend/go/bark/ LIBRARY_PATH=$(CURDIR)/backend/go/bark/ \
+	$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/bark-cpp ./backend/go/bark/
+ifneq ($(UPX),)
+	$(UPX) backend-assets/grpc/bark-cpp
+endif
+
 backend-assets/grpc/piper: sources/go-piper sources/go-piper/libpiper_binding.a backend-assets/grpc backend-assets/espeak-ng-data
 	CGO_CXXFLAGS="$(PIPER_CGO_CXXFLAGS)" CGO_LDFLAGS="$(PIPER_CGO_LDFLAGS)" LIBRARY_PATH=$(CURDIR)/sources/go-piper \
 	$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/piper ./backend/go/tts/
25 changes: 25 additions & 0 deletions backend/go/bark/Makefile
@@ -0,0 +1,25 @@
INCLUDE_PATH := $(abspath ./)
LIBRARY_PATH := $(abspath ./)

AR?=ar

BUILD_TYPE?=
# keep the C++ standard at C++17
CXXFLAGS = -I. -I$(INCLUDE_PATH)/../../../sources/bark.cpp/examples -I$(INCLUDE_PATH)/../../../sources/bark.cpp/spm-headers -I$(INCLUDE_PATH)/../../../sources/bark.cpp -O3 -DNDEBUG -std=c++17 -fPIC
LDFLAGS = -L$(LIBRARY_PATH) -L$(LIBRARY_PATH)/../../../sources/bark.cpp/build/examples -lbark -lstdc++ -lm

# warnings
CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function

gobark.o:
$(CXX) $(CXXFLAGS) gobark.cpp -o gobark.o -c $(LDFLAGS)

# bundle the upstream bark.cpp archive with the cgo wrapper object and the ggml objects it needs at link time
libbark.a: gobark.o
	cp $(INCLUDE_PATH)/../../../sources/bark.cpp/build/libbark.a ./
	$(AR) rcs libbark.a gobark.o
	$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml.c.o
	$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml-alloc.c.o
	$(AR) rcs libbark.a $(LIBRARY_PATH)/../../../sources/bark.cpp/build/encodec.cpp/ggml/src/CMakeFiles/ggml.dir/ggml-backend.c.o

clean:
rm -f gobark.o libbark.a
85 changes: 85 additions & 0 deletions backend/go/bark/gobark.cpp
@@ -0,0 +1,85 @@
#include <cstdio>
#include <vector>

#include "bark.h"
#include "gobark.h"
#include "common.h"
#include "ggml.h"

// global bark.cpp context, shared by load_model/tts/unload
struct bark_context *c;

void bark_print_progress_callback(struct bark_context *bctx, enum bark_encoding_step step, int progress, void *user_data) {
if (step == bark_encoding_step::SEMANTIC) {
printf("\rGenerating semantic tokens... %d%%", progress);
} else if (step == bark_encoding_step::COARSE) {
printf("\rGenerating coarse tokens... %d%%", progress);
} else if (step == bark_encoding_step::FINE) {
printf("\rGenerating fine tokens... %d%%", progress);
}
fflush(stdout);
}

int load_model(char *model) {
// initialize bark context
struct bark_context_params ctx_params = bark_context_default_params();
bark_params params;

params.model_path = model;

// ctx_params.verbosity = verbosity;
ctx_params.progress_callback = bark_print_progress_callback;
ctx_params.progress_callback_user_data = nullptr;

struct bark_context *bctx = bark_load_model(params.model_path.c_str(), ctx_params, params.seed);
if (!bctx) {
fprintf(stderr, "%s: Could not load model\n", __func__);
return 1;
}

c = bctx;

return 0;
}

int tts(char *text, int threads, char *dst) {

ggml_time_init();
const int64_t t_main_start_us = ggml_time_us();

// generate audio
if (!bark_generate_audio(c, text, threads)) {
fprintf(stderr, "%s: An error occured. If the problem persists, feel free to open an issue to report it.\n", __func__);
return 1;
}

const float *audio_data = bark_get_audio_data(c);
if (audio_data == NULL) {
fprintf(stderr, "%s: Could not get audio data\n", __func__);
return 1;
}

const int audio_arr_size = bark_get_audio_data_size(c);

std::vector<float> audio_arr(audio_data, audio_data + audio_arr_size);

write_wav_on_disk(audio_arr, dst);

// report timing
{
const int64_t t_main_end_us = ggml_time_us();
const int64_t t_load_us = bark_get_load_time(c);
const int64_t t_eval_us = bark_get_eval_time(c);

printf("\n\n");
printf("%s: load time = %8.2f ms\n", __func__, t_load_us / 1000.0f);
printf("%s: eval time = %8.2f ms\n", __func__, t_eval_us / 1000.0f);
printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f);
}

return 0;
}

int unload() {
    bark_free(c);
    return 0;
}

52 changes: 52 additions & 0 deletions backend/go/bark/gobark.go
@@ -0,0 +1,52 @@
package main

// #cgo CXXFLAGS: -I${SRCDIR}/../../../sources/bark.cpp/ -I${SRCDIR}/../../../sources/bark.cpp/encodec.cpp -I${SRCDIR}/../../../sources/bark.cpp/examples -I${SRCDIR}/../../../sources/bark.cpp/spm-headers
// #cgo LDFLAGS: -L${SRCDIR}/ -L${SRCDIR}/../../../sources/bark.cpp/build/examples -L${SRCDIR}/../../../sources/bark.cpp/build/encodec.cpp/ -lbark -lencodec -lcommon
// #include <gobark.h>
// #include <stdlib.h>
import "C"

import (
"fmt"
"unsafe"

"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)

type Bark struct {
base.SingleThread
threads int
}

func (sd *Bark) Load(opts *pb.ModelOptions) error {

sd.threads = int(opts.Threads)

modelFile := C.CString(opts.ModelFile)
defer C.free(unsafe.Pointer(modelFile))

	ret := C.load_model(modelFile)
	if ret != 0 {
		return fmt.Errorf("could not load model")
	}

return nil
}

func (sd *Bark) TTS(opts *pb.TTSRequest) error {
t := C.CString(opts.Text)
defer C.free(unsafe.Pointer(t))

dst := C.CString(opts.Dst)
defer C.free(unsafe.Pointer(dst))

threads := C.int(sd.threads)

ret := C.tts(t, threads, dst)
if ret != 0 {
return fmt.Errorf("inference failed")
}

return nil
}
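
For reference, a minimal sketch of how this binding is exercised in-process (the model path, text, and output path below are placeholders, not part of the PR):

package main

import (
	"log"

	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)

// runBarkOnce is a hypothetical driver: it calls the wrapper directly
// instead of going through the gRPC server in main.go.
func runBarkOnce() {
	b := &Bark{}
	// Load passes ModelFile to the cgo load_model() and remembers Threads.
	if err := b.Load(&pb.ModelOptions{ModelFile: "/models/bark-small.bin", Threads: 4}); err != nil {
		log.Fatal(err)
	}
	// TTS synthesizes Text via bark.cpp and writes a WAV file to Dst.
	if err := b.TTS(&pb.TTSRequest{Text: "Hello from bark.cpp", Dst: "/tmp/bark.wav"}); err != nil {
		log.Fatal(err)
	}
}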
8 changes: 8 additions & 0 deletions backend/go/bark/gobark.h
@@ -0,0 +1,8 @@
#ifdef __cplusplus
extern "C" {
#endif
int load_model(char *model);
int tts(char *text, int threads, char *dst);
int unload(void);
#ifdef __cplusplus
}
#endif
20 changes: 20 additions & 0 deletions backend/go/bark/main.go
@@ -0,0 +1,20 @@
package main

// Note: this is started internally by LocalAI and a server is allocated for each model
import (
"flag"

grpc "github.com/mudler/LocalAI/pkg/grpc"
)

var (
addr = flag.String("addr", "localhost:50051", "the address to connect to")
)

func main() {
flag.Parse()

if err := grpc.StartServer(*addr, &Bark{}); err != nil {
panic(err)
}
}
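
Beyond being spawned by LocalAI, the server can also be driven by hand for smoke testing. The sketch below assumes the generated Backend service client in pkg/grpc/proto (NewBackendClient with LoadModel/TTS RPCs) — verify those names against backend.proto before relying on them; the address and paths are placeholders:

package main

import (
	"context"
	"log"

	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	// Connect to a backend started with: bark-cpp --addr localhost:50051
	conn, err := grpc.Dial("localhost:50051", grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	client := pb.NewBackendClient(conn)
	ctx := context.Background()

	// Load the model once, then request a single synthesis.
	if _, err := client.LoadModel(ctx, &pb.ModelOptions{ModelFile: "/models/bark-small.bin", Threads: 4}); err != nil {
		log.Fatal(err)
	}
	if _, err := client.TTS(ctx, &pb.TTSRequest{Text: "Hello from bark.cpp", Dst: "/tmp/bark.wav"}); err != nil {
		log.Fatal(err)
	}
}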
67 changes: 0 additions & 67 deletions core/http/app_test.go
@@ -5,14 +5,12 @@ import (
"context"
"embed"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"strings"

"github.com/mudler/LocalAI/core/config"
. "github.com/mudler/LocalAI/core/http"
@@ -913,71 +911,6 @@ var _ = Describe("API test", func() {
 		})
 	})

-	Context("backends", func() {
-		It("runs rwkv completion", func() {
-			if runtime.GOOS != "linux" {
-				Skip("test supported only on linux")
-			}
-			resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,"})
-			Expect(err).ToNot(HaveOccurred())
-			Expect(len(resp.Choices) > 0).To(BeTrue())
-			Expect(resp.Choices[0].Text).To(ContainSubstring("five"))
-
-			stream, err := client.CreateCompletionStream(context.TODO(), openai.CompletionRequest{
-				Model: "rwkv_test", Prompt: "Count up to five: one, two, three, four,", Stream: true,
-			})
-			Expect(err).ToNot(HaveOccurred())
-			defer stream.Close()
-
-			tokens := 0
-			text := ""
-			for {
-				response, err := stream.Recv()
-				if errors.Is(err, io.EOF) {
-					break
-				}
-
-				Expect(err).ToNot(HaveOccurred())
-				text += response.Choices[0].Text
-				tokens++
-			}
-			Expect(text).ToNot(BeEmpty())
-			Expect(text).To(ContainSubstring("five"))
-			Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
-		})
-		It("runs rwkv chat completion", func() {
-			if runtime.GOOS != "linux" {
-				Skip("test supported only on linux")
-			}
-			resp, err := client.CreateChatCompletion(context.TODO(),
-				openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
-			Expect(err).ToNot(HaveOccurred())
-			Expect(len(resp.Choices) > 0).To(BeTrue())
-			Expect(strings.ToLower(resp.Choices[0].Message.Content)).To(Or(ContainSubstring("sure"), ContainSubstring("five"), ContainSubstring("5")))
-
-			stream, err := client.CreateChatCompletionStream(context.TODO(), openai.ChatCompletionRequest{Model: "rwkv_test", Messages: []openai.ChatCompletionMessage{{Content: "Can you count up to five?", Role: "user"}}})
-			Expect(err).ToNot(HaveOccurred())
-			defer stream.Close()
-
-			tokens := 0
-			text := ""
-			for {
-				response, err := stream.Recv()
-				if errors.Is(err, io.EOF) {
-					break
-				}
-
-				Expect(err).ToNot(HaveOccurred())
-				text += response.Choices[0].Delta.Content
-				tokens++
-			}
-			Expect(text).ToNot(BeEmpty())
-			Expect(strings.ToLower(text)).To(Or(ContainSubstring("sure"), ContainSubstring("five")))
-
-			Expect(tokens).ToNot(Or(Equal(1), Equal(0)))
-		})
-	})
-
 	// See tests/integration/stores_test
 	Context("Stores", Label("stores"), func() {
