forked from kubernetes-sigs/prow
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore: migrate code from forked
test-infra
repo
Signed-off-by: wuhuizuo <[email protected]>
- Loading branch information
Showing
45 changed files
with
3,711 additions
and
292 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
--- | ||
token: <openai-token> | ||
base_url: https://<api-url>/ | ||
api_type: AZURE | ||
api_version: 2023-03-15-preview | ||
engine: <deploy-engine-name> | ||
model: gpt-3.5-turbo |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"errors" | ||
"fmt" | ||
"os" | ||
"path" | ||
"sync" | ||
"time" | ||
|
||
"github.com/sirupsen/logrus" | ||
"gopkg.in/yaml.v3" | ||
) | ||
|
||
// ConfigAgent agent for fetch tasks with watching and hot reload. | ||
type ConfigAgent[T any] struct { | ||
path string | ||
config T | ||
mu sync.RWMutex | ||
} | ||
|
||
// WatchConfig monitors a file for changes and sends a message on the channel when the file changes | ||
func (c *ConfigAgent[T]) WatchConfig(ctx context.Context, interval time.Duration, onChangeHandler func(f string) error) { | ||
var lastMod time.Time | ||
ticker := time.NewTicker(interval) | ||
defer ticker.Stop() | ||
|
||
for { | ||
select { | ||
case <-ctx.Done(): | ||
return | ||
case <-ticker.C: | ||
logrus.Debug("ticker") | ||
info, err := os.Stat(c.path) | ||
if err != nil { | ||
fmt.Printf("Error getting file info: %v\n", err) | ||
} else if modTime := info.ModTime(); modTime.After(lastMod) { | ||
lastMod = modTime | ||
onChangeHandler(c.path) | ||
} | ||
} | ||
} | ||
} | ||
|
||
// Reload read and update config data. | ||
func (c *ConfigAgent[T]) Reload(file string, optionFuncs ...func() error) error { | ||
data, err := os.ReadFile(file) | ||
if err != nil { | ||
return fmt.Errorf("could no load config file %s: %w", file, err) | ||
} | ||
|
||
c.mu.Lock() | ||
defer c.mu.Unlock() | ||
switch path.Ext(file) { | ||
case ".json": | ||
if err := json.Unmarshal(data, &c.config); err != nil { | ||
return fmt.Errorf("could not unmarshal JSON config: %w", err) | ||
} | ||
case ".yaml", ".yml": | ||
if err := yaml.Unmarshal(data, &c.config); err != nil { | ||
return fmt.Errorf("could not unmarshal YAML config: %w", err) | ||
} | ||
default: | ||
return errors.New("only support file with `.json` or `.yaml` extension") | ||
} | ||
|
||
for _, f := range optionFuncs { | ||
if err := f(); err != nil { | ||
return err | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func (c *ConfigAgent[T]) Data() T { | ||
c.mu.RLock() | ||
defer c.mu.RUnlock() | ||
|
||
return c.config | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
"time" | ||
|
||
"github.com/sashabaranov/go-openai" | ||
) | ||
|
||
type OpenaiConfig struct { | ||
Token string `yaml:"token,omitempty" json:"token,omitempty"` | ||
BaseURL string `yaml:"base_url,omitempty" json:"base_url,omitempty"` | ||
OrgID string `yaml:"org_id,omitempty" json:"org_id,omitempty"` | ||
APIType string `yaml:"api_type,omitempty" json:"api_type,omitempty"` // OPEN_AI | AZURE | AZURE_AD | ||
APIVersion string `yaml:"api_version,omitempty" json:"api_version,omitempty"` // 2023-03-15-preview, required when APIType is APITypeAzure or APITypeAzureAD | ||
Engine string `yaml:"engine,omitempty" json:"engine,omitempty"` // required when APIType is APITypeAzure or APITypeAzureAD, it's the deploy instance name. | ||
Model string `yaml:"model,omitempty" json:"model,omitempty"` // OpenAI models, list ref: https://github.com/sashabaranov/go-openai/blob/master/completion.go#L15-L38 | ||
|
||
client *openai.Client | ||
} | ||
|
||
func (cfg *OpenaiConfig) initClient() error { | ||
if cfg.client == nil { | ||
openaiCfg := openai.DefaultConfig(cfg.Token) | ||
openaiCfg.BaseURL = cfg.BaseURL | ||
openaiCfg.OrgID = cfg.OrgID | ||
openaiCfg.APIType = openai.APIType(cfg.APIType) | ||
openaiCfg.APIVersion = cfg.APIVersion | ||
openaiCfg.Engine = cfg.Engine | ||
|
||
cfg.client = openai.NewClientWithConfig(openaiCfg) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// OpenaiAgent agent for openai clients with watching and hot reload. | ||
type OpenaiAgent struct { | ||
ConfigAgent[OpenaiConfig] | ||
} | ||
|
||
// NewOpenaiAgent returns a new openai loader. | ||
func NewOpenaiAgent(path string, watchInterval time.Duration) (*OpenaiAgent, error) { | ||
c := &OpenaiAgent{ConfigAgent: ConfigAgent[OpenaiConfig]{path: path}} | ||
if err := c.Reload(path); err != nil { | ||
return nil, err | ||
} | ||
|
||
go c.WatchConfig(context.Background(), watchInterval, c.Reload) | ||
|
||
return c, nil | ||
} | ||
|
||
func (a *OpenaiAgent) Reload(file string) error { | ||
return a.ConfigAgent.Reload(file, a.config.initClient) | ||
} | ||
|
||
type OpenaiWrapAgent struct { | ||
small *OpenaiAgent | ||
large *OpenaiAgent | ||
largeDownThreshold int | ||
} | ||
|
||
// NewOpenaiAgent returns a new openai loader. | ||
func NewWrapOpenaiAgent(defaultPath, largePath string, largeDownThreshold int, watchInterval time.Duration) (*OpenaiWrapAgent, error) { | ||
d, err := NewOpenaiAgent(defaultPath, watchInterval) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
ret := &OpenaiWrapAgent{small: d, largeDownThreshold: largeDownThreshold} | ||
if largePath != "" { | ||
l, err := NewOpenaiAgent(largePath, watchInterval) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
ret.large = l | ||
} | ||
|
||
return ret, nil | ||
} | ||
|
||
func (a *OpenaiWrapAgent) ClientFor(msgLen int) (*openai.Client, string) { | ||
if a.large != nil && | ||
a.largeDownThreshold > 0 && | ||
msgLen > a.largeDownThreshold { | ||
return a.large.Data().client, a.large.Data().Model | ||
} | ||
|
||
return a.small.Data().client, a.small.Data().Model | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"regexp" | ||
"time" | ||
|
||
"github.com/sirupsen/logrus" | ||
) | ||
|
||
const ( | ||
defaultSystemMessage = "You are an experienced software developer. You will act as a reviewer for a GitHub Pull Request, and you should answer by markdown format." | ||
defaultPromte = "Please identify potential problems and give some fixing suggestions." | ||
defaultPrPatchIntroducePromte = "This is the diff for the pull request:" | ||
defaultMaxResponseTokens = 500 | ||
defaultTemperature = 0.7 | ||
defaultStaticOutHeadnote = `> **I have already done a preliminary review for you, and I hope to help you do a better job.** | ||
------ | ||
` | ||
) | ||
|
||
// TasksConfig represent the all tasks store for the plugin. | ||
// | ||
// layer: org|repo / task-name / task-config | ||
type TasksConfig map[string]RepoTasks | ||
|
||
// RepoTasks represent the tasks for a repo or ORG. | ||
type RepoTasks map[string]*Task | ||
|
||
// Task reprensent the config for AI task item. | ||
// | ||
// $SystemMessage | ||
// -------------- | ||
// < $Prompt | ||
// < Here are the serval context contents: | ||
// $ExternalContexts.each do | ||
// | ||
// < - format(it.PromptTpl, fetch(it.ResURL)) | ||
// | ||
// < $PatchIntroducerPrompt | ||
// < ```diff | ||
// < diff content | ||
// < ``` | ||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
// > <OutputStaticHeadNote> | ||
// > responses from AI server. | ||
// | ||
// TODO(wuhuizuo): using go template to comose the question. | ||
type Task struct { | ||
Description string `yaml:"description,omitempty" json:"description,omitempty"` | ||
SystemMessage string `yaml:"system_message,omitempty" json:"system_message,omitempty"` | ||
UserPrompt string `yaml:"user_prompt,omitempty" json:"user_prompt,omitempty"` | ||
PatchIntroducePrompt string `yaml:"patch_introduce_prompt,omitempty" json:"patch_introduce_prompt,omitempty"` | ||
OutputStaticHeadNote string `yaml:"output_static_head_note,omitempty" json:"output_static_head_note,omitempty"` | ||
MaxResponseTokens int `yaml:"max_response_tokens,omitempty" json:"max_response_tokens,omitempty"` | ||
ExternalContexts []*ExternalContext `yaml:"external_contexts,omitempty" json:"external_contexts,omitempty"` | ||
|
||
AlwaysRun bool `yaml:"always_run,omitempty" json:"always_run,omitempty"` // automatic run or should triggered by comments. | ||
SkipAuthors []string `yaml:"skip_authors,omitempty" json:"skip_authors,omitempty"` // skip the pull request created by the authors. | ||
SkipBrancheRegs []string `yaml:"skip_branche_regs,omitempty" json:"skip_branche_regs,omitempty"` // skip the pull requests whiches target branch matched the regex. | ||
SkipLabelRegs []string `yaml:"skip_label_regs,omitempty" json:"skip_label_regs,omitempty"` // skip the pull reqeusts when any labels matched on the pull request. | ||
|
||
skipBrancheRegs []*regexp.Regexp `yaml:"-" json:"-"` | ||
skipLabelRegs []*regexp.Regexp `yaml:"-" json:"-"` | ||
} | ||
|
||
func (t *Task) initRegexps() error { | ||
if len(t.SkipBrancheRegs) != 0 { | ||
for _, br := range t.SkipBrancheRegs { | ||
reg, err := regexp.Compile(br) | ||
if err != nil { | ||
return err | ||
} | ||
t.skipBrancheRegs = append(t.skipBrancheRegs, reg) | ||
} | ||
} | ||
|
||
if len(t.SkipLabelRegs) != 0 { | ||
for _, br := range t.SkipLabelRegs { | ||
reg, err := regexp.Compile(br) | ||
if err != nil { | ||
return err | ||
} | ||
t.skipLabelRegs = append(t.skipLabelRegs, reg) | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
type ExternalContext struct { | ||
PromptTpl string `yaml:"prompt_tpl,omitempty" json:"prompt_tpl,omitempty"` | ||
ResURL string `yaml:"res_url,omitempty" json:"res_url,omitempty"` | ||
resContent []byte //nolint: unused // to be implemented. | ||
} | ||
|
||
func (ec *ExternalContext) Content() ([]byte, error) { | ||
return nil, errors.New("not implemented") | ||
} | ||
|
||
// TaskAgent agent for fetch tasks with watching and hot reload. | ||
type TaskAgent struct { | ||
ConfigAgent[TasksConfig] | ||
} | ||
|
||
// NewTaskAgent returns a new ConfigLoader. | ||
func NewTaskAgent(path string, watchInterval time.Duration) (*TaskAgent, error) { | ||
c := &TaskAgent{ConfigAgent: ConfigAgent[TasksConfig]{path: path}} | ||
if err := c.Reload(path); err != nil { | ||
return nil, err | ||
} | ||
|
||
go c.WatchConfig(context.Background(), watchInterval, c.Reload) | ||
|
||
return c, nil | ||
} | ||
|
||
func (a *TaskAgent) Reload(file string) error { | ||
return a.ConfigAgent.Reload(file, func() error { | ||
for _, repoCfg := range a.config { | ||
for _, task := range repoCfg { | ||
if err := task.initRegexps(); err != nil { | ||
return err | ||
} | ||
} | ||
} | ||
return nil | ||
}) | ||
} | ||
|
||
// Get return the config data. | ||
func (c *TaskAgent) TasksFor(org, repo string) (map[string]*Task, error) { | ||
c.mu.RLock() | ||
defer c.mu.RUnlock() | ||
|
||
fullName := fmt.Sprintf("%s/%s", org, repo) | ||
repoTasks, ok := c.Data()[fullName] | ||
if ok { | ||
return repoTasks, nil | ||
} | ||
logrus.Debugf("no tasks for repo: %s", fullName) | ||
|
||
orgTasks, ok := c.config[org] | ||
if ok { | ||
return orgTasks, nil | ||
} | ||
logrus.Debugf("no tasks for org %s", org) | ||
logrus.Debugf("all tasks: %#v", c.config) | ||
return nil, nil | ||
} | ||
|
||
// Task return the given task config. | ||
func (c *TaskAgent) Task(org, repo, taskName string, needDefault bool) (*Task, error) { | ||
tasks, err := c.TasksFor(org, repo) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
task := tasks[taskName] | ||
if task != nil { | ||
return task, nil | ||
} | ||
|
||
if needDefault { | ||
return c.DefaultTask(), nil | ||
} | ||
|
||
return nil, nil | ||
} | ||
|
||
func (c *TaskAgent) DefaultTask() *Task { | ||
return &Task{ | ||
SystemMessage: defaultSystemMessage, | ||
UserPrompt: defaultPromte, | ||
MaxResponseTokens: defaultMaxResponseTokens, | ||
PatchIntroducePrompt: defaultPrPatchIntroducePromte, | ||
} | ||
} |
Oops, something went wrong.