Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ setup-python:
python3 -m venv .venv && \
source .venv/bin/activate && \
pip install --upgrade pip && \
pip install torch transformers
pip install torch transformers accelerate

build-granite:
go build -o $(GRANITE_RUNNER) model/granite-runner.go
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ bin/tree-ai ./ --verbose

```bash
❯ bin/tree-ai ./ --endpoint="<model endpoint>" --truncate
⚠️ AI-generated summaries may be inaccurate or outdated. Always verify important details.
⚠️ AI-generated summaries may be inaccurate or outdated.
└── LICENSE ➤ grants users permission to use, modify, and distribute the project's software
└── Makefile ➤ as a build and testing automation tool for the tree-ai project
└── README.md ➤ This file serves as the project's documentation and user guide
Expand Down
2 changes: 1 addition & 1 deletion cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ var rootCmd = &cobra.Command{
Use: "tree-ai",
Short: "AI-enhanced tree command",
Run: func(cmd *cobra.Command, args []string) {
fmt.Fprintln(os.Stdout, "⚠️ AI-generated summaries may be inaccurate or outdated. Always verify important details.")
fmt.Fprintln(os.Stdout, "⚠️ AI-generated summaries may be inaccurate or outdated.")
dir := "."
if len(args) > 0 {
dir = args[0]
Expand Down
234 changes: 170 additions & 64 deletions internal/ai/describe.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ai

import (
"bufio"
"bytes"
"encoding/json"
"fmt"
Expand All @@ -11,23 +12,32 @@ import (
"path/filepath"
"regexp"
"strings"
"sync"
"time"
)

var fileCounter int
var totalFiles int
var (
fileCounter int
totalFiles int
localModelProcess *exec.Cmd
Verbose bool = false
TruncateDescriptions bool = true
modelWriter io.WriteCloser
modelReader *bufio.Reader
modelLock sync.Mutex
)

func SetTotalFiles(n int) {
totalFiles = n
}

var Verbose bool = false
var TruncateDescriptions bool = true

func Describe(path string, isDir bool, model, userEndpoint, userInstruction string) string {
itemType := map[bool]string{true: "directory", false: "file"}[isDir]
target := filepath.Base(path)
content := collectContent(path, isDir)
if Verbose {
fmt.Fprintf(os.Stderr, "[tree-ai] collected content for %s (%s):\n%s\n", path, itemType, content)
}

instruction := userInstruction
if instruction == "" {
Expand All @@ -49,17 +59,9 @@ This is a %s named "%s". Its contents are:
endpoint = os.Getenv("TREE_AI_ENDPOINT")
}

if endpoint == "" {
if Verbose {
fmt.Fprintln(os.Stderr, "[tree-ai] no remote endpoint configured, falling back to local model.")
}
return formatFinalResponse(target, cleanModelResponse(fallback(target, isDir, model, prompt), target, isDir), isDir)
}

healthURL := strings.Replace(endpoint, "/v1/completions", "/health", 1)
if !isEndpointAvailable(healthURL) {
if endpoint == "" || !isEndpointAvailable(strings.Replace(endpoint, "/v1/completions", "/health", 1)) {
if Verbose {
fmt.Fprintf(os.Stderr, "[tree-ai] remote endpoint %s not available, falling back to local model.\n", endpoint)
fmt.Fprintln(os.Stderr, "[tree-ai] falling back to local model.")
}
return formatFinalResponse(target, cleanModelResponse(fallback(target, isDir, model, prompt), target, isDir), isDir)
}
Expand Down Expand Up @@ -104,6 +106,125 @@ This is a %s named "%s". Its contents are:
return formatFinalResponse(target, cleanModelResponse(result.Choices[0].Text, target, isDir), isDir)
}

func fallback(target string, isDir bool, model string, fullPrompt string) string {
var outputBuf bytes.Buffer

if err := ensureModelRunning(); err != nil {
if Verbose {
fmt.Fprintf(os.Stderr, "[tree-ai] local model init failed: %v\n", err)
}
return defaultFallbackText(isDir)
}

modelLock.Lock()
defer modelLock.Unlock()

if Verbose {
fmt.Fprintf(os.Stderr, "[tree-ai] Sending prompt to local model for %s:\n%s\n", target, fullPrompt)
}

fmt.Fprintln(modelWriter, fullPrompt)
fmt.Fprintln(modelWriter, "<<END>>")

readStart := time.Now()
for {
if time.Since(readStart) > 5*time.Second {
if Verbose {
fmt.Fprintf(os.Stderr, "[tree-ai] Timeout reading model output for %s\n", target)
}
break
}
line, err := modelReader.ReadString('\n')
if err != nil && err != io.EOF {
if Verbose {
fmt.Fprintf(os.Stderr, "[tree-ai] Error reading response for %s: %v\n", target, err)
}
break
}
outputBuf.WriteString(line)
if strings.TrimSpace(line) == "<<END>>" || strings.HasSuffix(line, "\n\n") {
break
}
}

response := strings.TrimSpace(outputBuf.String())
if Verbose {
fmt.Fprintf(os.Stderr, "[tree-ai] Final raw model output for %s:\n%s\n", target, response)
}

if response != "" && response != "." {
return response
}
return defaultFallbackText(isDir)
}

func ensureModelRunning() error {
modelLock.Lock()
defer modelLock.Unlock()

if localModelProcess != nil {
return nil
}

fmt.Fprintln(os.Stderr, "🔄 Launching local AI model...")

exePath, err := os.Executable()
if err != nil {
return fmt.Errorf("could not get executable path: %w", err)
}
projectRoot := filepath.Dir(filepath.Dir(exePath))
pythonScriptPath := filepath.Join(projectRoot, "model", "granite_stream.py")
venvPythonPath := filepath.Join(projectRoot, ".venv", "bin", "python")

cmd := exec.Command(venvPythonPath, pythonScriptPath)
cmd.Dir = projectRoot
cmd.Env = append(os.Environ(), "PYTHONUNBUFFERED=1")

stdin, err := cmd.StdinPipe()
if err != nil {
return fmt.Errorf("stdin pipe error: %w", err)
}
stdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("stdout pipe error: %w", err)
}
stderr, err := cmd.StderrPipe()
if err != nil {
return fmt.Errorf("stderr pipe error: %w", err)
}

go func() { io.Copy(os.Stderr, stderr) }()

if err := cmd.Start(); err != nil {
return fmt.Errorf("start error: %w", err)
}

modelWriter = stdin
modelReader = bufio.NewReader(stdout)
localModelProcess = cmd

fmt.Fprintln(os.Stderr, "⏳ Waiting for model to become ready...")

line, err := modelReader.ReadString('\n')
if err != nil {
return fmt.Errorf("model failed to start: %v", err)
}

if !strings.Contains(line, "[READY]") {
return fmt.Errorf("unexpected output from model: %q", line)
}

fmt.Fprintln(os.Stderr, "✅ Local model ready.")
return nil
}

func defaultFallbackText(isDir bool) string {
if isDir {
return "Internal directory for project logic."
}
return "Internal project file."
}

func cleanModelResponse(rawText string, target string, isDir bool) string {
original := strings.TrimSpace(rawText)
text := original
Expand All @@ -129,6 +250,23 @@ func cleanModelResponse(rawText string, target string, isDir bool) string {
return text
}

func summarizeToOneLine(s string) string {
s = strings.ReplaceAll(s, "\n", " ")
s = strings.ReplaceAll(s, "\r", " ")
s = strings.Join(strings.Fields(s), " ")
s = strings.TrimSpace(s)

const maxLen = 120
if len(s) > maxLen {
s = s[:maxLen]
if i := strings.LastIndex(s, " "); i > 0 {
s = s[:i]
}
s += "..."
}
return s
}

func formatFinalResponse(label string, desc string, isDir bool) string {
arrow := "\033[38;5;208m➤\033[0m"
desc = strings.TrimSpace(desc)
Expand All @@ -138,18 +276,24 @@ func formatFinalResponse(label string, desc string, isDir bool) string {
return fmt.Sprintf("%s %s", arrow, desc)
}

func isBinary(path string) bool {
f, err := os.Open(path)
if err != nil {
return false
}
defer f.Close()

func fallback(target string, isDir bool, model string, fullPrompt string) string {
cmd := exec.Command(".venv/bin/python", "model/granite_infer.py", "--prompt", fullPrompt)
cmd.Env = append(os.Environ(), "TRANSFORMERS_CACHE=.hf-cache")
output, err := cmd.Output()
if err == nil && len(output) > 0 {
return string(output)
buf := make([]byte, 800)
n, _ := f.Read(buf)
if n == 0 {
return false
}
if isDir {
return "Internal directory for project logic."
for _, b := range buf[:n] {
if b == 0 {
return true
}
}
return "Internal project file."
return false
}

func isEndpointAvailable(url string) bool {
Expand Down Expand Up @@ -187,6 +331,7 @@ func collectContent(path string, isDir bool) string {
builder.WriteString("\n... [truncated]")
}
}

if !isDir {
addFile(path)
} else {
Expand All @@ -201,49 +346,10 @@ func collectContent(path string, isDir bool) string {
return nil
})
}

result := builder.String()
if len(result) > maxTotalBytes {
result = result[:maxTotalBytes] + "\n... [truncated]"
}
return result
}

func summarizeToOneLine(s string) string {
s = strings.ReplaceAll(s, "\n", " ")
s = strings.ReplaceAll(s, "\r", " ")
s = strings.Join(strings.Fields(s), " ") // collapse multiple spaces
s = strings.TrimSpace(s)

const maxLen = 120
if len(s) > maxLen {
s = s[:maxLen]
// Optionally trim mid-word if needed
if i := strings.LastIndex(s, " "); i > 0 {
s = s[:i]
}
s += "..."
}

return s
}


func isBinary(path string) bool {
f, err := os.Open(path)
if err != nil {
return false
}
defer f.Close()

buf := make([]byte, 800)
n, _ := f.Read(buf)
if n == 0 {
return false
}
for _, b := range buf[:n] {
if b == 0 {
return true
}
}
return false
}
Binary file added model/granite-runner
Binary file not shown.
Loading