From 20d662101868309ffb0fb31073e36a8263780638 Mon Sep 17 00:00:00 2001 From: Anthony DePasquale Date: Wed, 11 Dec 2024 21:00:51 +0100 Subject: [PATCH] Add functionality from TypeScript implementation --- Sources/Ast.swift | 30 +- Sources/Environment.swift | 204 ++++--- Sources/Lexer.swift | 6 +- Sources/Parser.swift | 195 ++++--- Sources/Runtime.swift | 1015 ++++++++++++++++++++++++++------- Sources/Utilities.swift | 96 ++++ Tests/ChatTemplateTests.swift | 434 +++++++++++++- Tests/InterpreterTests.swift | 11 +- Tests/ToolUseTests.swift | 397 +++++++++++++ 9 files changed, 2007 insertions(+), 381 deletions(-) create mode 100644 Tests/ToolUseTests.swift diff --git a/Sources/Ast.swift b/Sources/Ast.swift index 7460284..8f2a60d 100644 --- a/Sources/Ast.swift +++ b/Sources/Ast.swift @@ -41,7 +41,7 @@ struct TupleLiteral: Literal { } struct ObjectLiteral: Literal { - var value: [(Expression, Expression)] + var value: [String: Expression] } struct Set: Statement { @@ -49,7 +49,7 @@ struct Set: Statement { var value: Expression } -struct If: Statement { +struct If: Statement, Expression { var test: Expression var body: [Statement] var alternate: [Statement] @@ -59,14 +59,14 @@ struct Identifier: Expression { var value: String } -protocol Loopvar {} -extension Identifier: Loopvar {} -extension TupleLiteral: Loopvar {} +typealias Loopvar = Expression struct For: Statement { var loopvar: Loopvar var iterable: Expression var body: [Statement] + var defaultBlock: [Statement] + var ifCondition: Expression? } struct MemberExpression: Expression { @@ -124,3 +124,23 @@ struct KeywordArgumentExpression: Expression { struct NullLiteral: Literal { var value: Any? = nil } + +struct SelectExpression: Expression { + var iterable: Expression + var test: Expression +} + +struct Macro: Statement { + var name: Identifier + var args: [Expression] + var body: [Statement] +} + +struct KeywordArgumentsValue: RuntimeValue { + var value: [String: any RuntimeValue] + var builtins: [String: any RuntimeValue] = [:] + + func bool() -> Bool { + !value.isEmpty + } +} diff --git a/Sources/Environment.swift b/Sources/Environment.swift index c845068..555b4cd 100644 --- a/Sources/Environment.swift +++ b/Sources/Environment.swift @@ -12,42 +12,39 @@ class Environment { var variables: [String: any RuntimeValue] = [ "namespace": FunctionValue(value: { args, _ in - if args.count == 0 { + if args.isEmpty { return ObjectValue(value: [:]) } - if args.count != 1 || !(args[0] is ObjectValue) { + guard args.count == 1, let objectArg = args[0] as? ObjectValue else { throw JinjaError.runtime("`namespace` expects either zero arguments or a single object argument") } - return args[0] + return objectArg }) ] var tests: [String: (any RuntimeValue...) throws -> Bool] = [ - "boolean": { - args in - args[0] is BooleanValue + "boolean": { args in + return args[0] is BooleanValue }, - "callable": { - args in - args[0] is FunctionValue + "callable": { args in + return args[0] is FunctionValue }, - "odd": { - args in - if let arg = args.first as? NumericValue { - return arg.value as! Int % 2 != 0 + "odd": { args in + if let arg = args.first as? NumericValue, let intValue = arg.value as? Int { + return intValue % 2 != 0 } else { - throw JinjaError.runtime("Cannot apply test 'odd' to type: \(type(of:args.first))") + throw JinjaError.runtime("Cannot apply test 'odd' to type: \(type(of: args.first))") } }, "even": { args in - if let arg = args.first as? NumericValue { - return arg.value as! Int % 2 == 0 + if let arg = args.first as? NumericValue, let intValue = arg.value as? Int { + return intValue % 2 == 0 } else { - throw JinjaError.runtime("Cannot apply test 'even' to type: \(type(of:args.first))") + throw JinjaError.runtime("Cannot apply test 'even' to type: \(type(of: args.first))") } }, "false": { args in @@ -62,24 +59,28 @@ class Environment { } return false }, + "string": { args in + return args[0] is StringValue + }, "number": { args in - args[0] is NumericValue + return args[0] is NumericValue }, "integer": { args in if let arg = args[0] as? NumericValue { return arg.value is Int } - return false }, + "mapping": { args in + return args[0] is ObjectValue + }, "iterable": { args in - args[0] is ArrayValue || args[0] is StringValue + return args[0] is ArrayValue || args[0] is StringValue || args[0] is ObjectValue }, "lower": { args in if let arg = args[0] as? StringValue { return arg.value == arg.value.lowercased() } - return false }, "upper": { args in @@ -89,16 +90,47 @@ class Environment { return false }, "none": { args in - args[0] is NullValue + return args[0] is NullValue }, "defined": { args in - !(args[0] is UndefinedValue) + return !(args[0] is UndefinedValue) }, "undefined": { args in - args[0] is UndefinedValue + return args[0] is UndefinedValue }, - "equalto": { _ in - throw JinjaError.syntaxNotSupported("equalto") + "equalto": { args in + if args.count == 2 { + if let left = args[0] as? StringValue, let right = args[1] as? StringValue { + return left.value == right.value + } else if let left = args[0] as? NumericValue, let right = args[1] as? NumericValue, + let leftInt = left.value as? Int, let rightInt = right.value as? Int + { + return leftInt == rightInt + } else if let left = args[0] as? BooleanValue, let right = args[1] as? BooleanValue { + return left.value == right.value + } else { + return false + } + } else { + return false + } + }, + "eq": { args in + if args.count == 2 { + if let left = args[0] as? StringValue, let right = args[1] as? StringValue { + return left.value == right.value + } else if let left = args[0] as? NumericValue, let right = args[1] as? NumericValue, + let leftInt = left.value as? Int, let rightInt = right.value as? Int + { + return leftInt == rightInt + } else if let left = args[0] as? BooleanValue, let right = args[1] as? BooleanValue { + return left.value == right.value + } else { + return false + } + } else { + return false + } }, ] @@ -107,66 +139,111 @@ class Environment { } func isFunction(_ value: Any, functionType: T.Type) -> Bool { - value is T + return value is T } - func convertToRuntimeValues(input: Any) throws -> any RuntimeValue { + func convertToRuntimeValues(input: Any?) throws -> any RuntimeValue { + if input == nil { + return NullValue() + } switch input { case let value as Bool: return BooleanValue(value: value) - case let values as [any Numeric]: - var items: [any RuntimeValue] = [] - for value in values { - try items.append(self.convertToRuntimeValues(input: value)) - } - return ArrayValue(value: items) case let value as any Numeric: return NumericValue(value: value) case let value as String: return StringValue(value: value) + case let data as Data: + guard let string = String(data: data, encoding: .utf8) else { + throw JinjaError.runtime("Failed to convert data to string") + } + return StringValue(value: string) case let fn as (String) throws -> Void: return FunctionValue { args, _ in - var arg = "" - switch args[0].value { - case let value as String: - arg = value - case let value as Bool: - arg = String(value) - default: - throw JinjaError.runtime("Unknown arg type:\(type(of: args[0].value))") + guard let stringArg = args[0] as? StringValue else { + throw JinjaError.runtime("Argument must be a StringValue") } - - try fn(arg) + try fn(stringArg.value) return NullValue() } case let fn as (Bool) throws -> Void: return FunctionValue { args, _ in - try fn(args[0].value as! Bool) + guard let boolArg = args[0] as? BooleanValue else { + throw JinjaError.runtime("Argument must be a BooleanValue") + } + try fn(boolArg.value) return NullValue() } case let fn as (Int, Int?, Int) -> [Int]: return FunctionValue { args, _ in - let result = fn(args[0].value as! Int, args[1].value as? Int, args[2].value as! Int) + guard args.count > 0, let arg0 = args[0] as? NumericValue, let int0 = arg0.value as? Int else { + throw JinjaError.runtime("First argument must be an Int") + } + var int1: Int? = nil + if args.count > 1 { + if let numericValue = args[1] as? NumericValue, let tempInt1 = numericValue.value as? Int { + int1 = tempInt1 + } else { + throw JinjaError.runtime("Second argument must be an Int or nil") + } + } + var int2: Int = 1 + if args.count > 2 { + if let numericValue = args[2] as? NumericValue, let tempInt2 = numericValue.value as? Int { + int2 = tempInt2 + } else { + throw JinjaError.runtime("Third argument must be an Int") + } + } + let result = fn(int0, int1, int2) return try self.convertToRuntimeValues(input: result) } - case let values as [Any]: - var items: [any RuntimeValue] = [] - for value in values { - try items.append(self.convertToRuntimeValues(input: value)) + case let fn as ([Int]) -> [Int]: + return FunctionValue { args, _ in + let intArgs = args.compactMap { ($0 as? NumericValue)?.value as? Int } + guard intArgs.count == args.count else { + throw JinjaError.runtime("Arguments to range must be Ints") + } + let result = fn(intArgs) + return try self.convertToRuntimeValues(input: result) + } + case let fn as (Int, Int?, Int) -> [Int]: + return FunctionValue { args, _ in + guard let arg0 = args[0] as? NumericValue, let int0 = arg0.value as? Int else { + throw JinjaError.runtime("First argument must be an Int") + } + let int1 = (args.count > 1) ? (args[1] as? NumericValue)?.value as? Int : nil + guard let arg2 = args.last as? NumericValue, let int2 = arg2.value as? Int else { + throw JinjaError.runtime("Last argument must be an Int") + } + let result = fn(int0, int1, int2) + return try self.convertToRuntimeValues(input: result) } + case let values as [Any]: + let items = try values.map { try self.convertToRuntimeValues(input: $0) } return ArrayValue(value: items) - case let dictionary as [String: String]: + case let dictionary as [String: Any?]: + // Create ordered pairs from the dictionary, maintaining original order + let orderedPairs = Array(dictionary) var object: [String: any RuntimeValue] = [:] + var keyOrder: [String] = [] - for (key, value) in dictionary { - object[key] = StringValue(value: value) + // Convert values while maintaining order + for (key, value) in orderedPairs { + // Convert nil values to NullValue + object[key] = try self.convertToRuntimeValues(input: value) + keyOrder.append(key) } - return ObjectValue(value: object) + // Use the original order from orderedPairs for keyOrder + return ObjectValue(value: object, keyOrder: keyOrder) + case is NullValue: return NullValue() default: - throw JinjaError.runtime("Cannot convert to runtime value: \(input) type:\(type(of: input))") + throw JinjaError.runtime( + "Cannot convert to runtime value: \(String(describing: input)) type:\(type(of: input))" + ) } } @@ -176,12 +253,11 @@ class Environment { } func declareVariable(name: String, value: any RuntimeValue) throws -> any RuntimeValue { - if self.variables.contains(where: { $0.0 == name }) { + if self.variables.keys.contains(name) { throw JinjaError.syntax("Variable already declared: \(name)") } self.variables[name] = value - return value } @@ -191,13 +267,13 @@ class Environment { return value } - func resolve(name: String) throws -> Self { - if self.variables.contains(where: { $0.0 == name }) { + func resolve(name: String) throws -> Environment { + if self.variables.keys.contains(name) { return self } - if let parent { - return try parent.resolve(name: name) as! Self + if let parent = self.parent { + return try parent.resolve(name: name) } throw JinjaError.runtime("Unknown variable: \(name)") @@ -205,11 +281,7 @@ class Environment { func lookupVariable(name: String) -> any RuntimeValue { do { - if let value = try self.resolve(name: name).variables[name] { - return value - } else { - return UndefinedValue() - } + return try self.resolve(name: name).variables[name] ?? UndefinedValue() } catch { return UndefinedValue() } diff --git a/Sources/Lexer.swift b/Sources/Lexer.swift index 1093960..f6473e9 100644 --- a/Sources/Lexer.swift +++ b/Sources/Lexer.swift @@ -50,6 +50,8 @@ enum TokenType: String { case and = "And" case or = "Or" case not = "Not" + case macro = "Macro" + case endMacro = "EndMacro" } struct Token: Equatable { @@ -70,6 +72,8 @@ let keywords: [String: TokenType] = [ "and": .and, "or": .or, "not": .not, + "macro": .macro, + "endmacro": .endMacro, // Literals "true": .booleanLiteral, "false": .booleanLiteral, @@ -81,7 +85,7 @@ func isWord(char: String) -> Bool { } func isInteger(char: String) -> Bool { - char.range(of: #"[0-9]"#, options: .regularExpression) != nil + char.range(of: #"^[0-9]+$"#, options: .regularExpression) != nil } func isWhile(char: String) -> Bool { diff --git a/Sources/Parser.swift b/Sources/Parser.swift index 648a025..83aad16 100644 --- a/Sources/Parser.swift +++ b/Sources/Parser.swift @@ -22,7 +22,7 @@ func parse(tokens: [Token]) throws -> Program { return prev } - func parseArgumentsList() throws -> [Statement] { + func parseArgumentsList() throws -> [Expression] { var args: [Expression] = [] while !typeof(.closeParen) { @@ -33,13 +33,13 @@ func parse(tokens: [Token]) throws -> Program { if let identifier = argument as? Identifier { let value = try parseExpression() - argument = KeywordArgumentExpression(key: identifier, value: value as! Expression) + argument = KeywordArgumentExpression(key: identifier, value: value) } else { throw JinjaError.syntax("Expected identifier for keyword argument") } } - args.append(argument as! Expression) + args.append(argument) if typeof(.comma) { current += 1 @@ -49,7 +49,7 @@ func parse(tokens: [Token]) throws -> Program { return args } - func parseArgs() throws -> [Statement] { + func parseArgs() throws -> [Expression] { try expect(type: .openParen, error: "Expected opening parenthesis for arguments list") let args = try parseArgumentsList() @@ -63,14 +63,10 @@ func parse(tokens: [Token]) throws -> Program { try StringLiteral(value: expect(type: .text, error: "Expected text token").value) } - func parseCallExpression(callee: Statement) throws -> CallExpression { - var args: [Expression] = [] - - for arg in try parseArgs() { - args.append(arg as! Expression) - } + func parseCallExpression(callee: Expression) throws -> CallExpression { + let args = try parseArgs() - var callExpression = CallExpression(callee: callee as! Expression, args: args) + var callExpression = CallExpression(callee: callee, args: args) if typeof(.openParen) { callExpression = try parseCallExpression(callee: callExpression) @@ -79,8 +75,8 @@ func parse(tokens: [Token]) throws -> Program { return callExpression } - func parseMemberExpressionArgumentsList() throws -> Statement { - var slices: [Statement?] = [] + func parseMemberExpressionArgumentsList() throws -> Expression { + var slices: [Expression?] = [] var isSlice = false while !typeof(.closeSquareBracket) { @@ -89,7 +85,7 @@ func parse(tokens: [Token]) throws -> Program { current += 1 isSlice = true } else { - try slices.append(parseExpression()) + slices.append(try parseExpression()) if typeof(.colon) { current += 1 isSlice = true @@ -107,22 +103,22 @@ func parse(tokens: [Token]) throws -> Program { } return SliceExpression( - start: slices[0] as? Expression, - stop: slices.count > 1 ? slices[1] as? Expression : nil, - step: slices.count > 2 ? slices[2] as? Expression : nil + start: slices[0], + stop: slices.count > 1 ? slices[1] : nil, + step: slices.count > 2 ? slices[2] : nil ) } return slices[0]! } - func parseMemberExpression() throws -> Statement { + func parseMemberExpression() throws -> Expression { var object = try parsePrimaryExpression() while typeof(.dot) || typeof(.openSquareBracket) { let operation = tokens[current] current += 1 - var property: Statement + var property: Expression let computed = operation.type != .dot @@ -137,8 +133,8 @@ func parse(tokens: [Token]) throws -> Program { } object = MemberExpression( - object: object as! Expression, - property: property as! Expression, + object: object, + property: property, computed: computed ) } @@ -146,7 +142,7 @@ func parse(tokens: [Token]) throws -> Program { return object } - func parseCallMemberExpression() throws -> Statement { + func parseCallMemberExpression() throws -> Expression { let member = try parseMemberExpression() if typeof(.openParen) { @@ -156,7 +152,7 @@ func parse(tokens: [Token]) throws -> Program { return member } - func parseFilterExpression() throws -> Statement { + func parseFilterExpression() throws -> Expression { var operand = try parseCallMemberExpression() while typeof(.pipe) { @@ -171,14 +167,14 @@ func parse(tokens: [Token]) throws -> Program { } if let filter = filter as? Filter { - operand = FilterExpression(operand: operand as! Expression, filter: filter) + operand = FilterExpression(operand: operand, filter: filter) } } return operand } - func parseTestExpression() throws -> Statement { + func parseTestExpression() throws -> Expression { var operand = try parseFilterExpression() while typeof(.is) { @@ -194,7 +190,7 @@ func parse(tokens: [Token]) throws -> Program { filter = Identifier(value: "none") } if let test = filter as? Identifier { - operand = TestExpression(operand: operand as! Expression, negate: negate, test: test) + operand = TestExpression(operand: operand, negate: negate, test: test) } else { throw JinjaError.syntax("Expected identifier for the test") } @@ -202,49 +198,49 @@ func parse(tokens: [Token]) throws -> Program { return operand } - func parseMultiplicativeExpression() throws -> Statement { + func parseMultiplicativeExpression() throws -> Expression { var left = try parseTestExpression() while typeof(.multiplicativeBinaryOperator) { let operation = tokens[current] current += 1 let right = try parseTestExpression() - left = BinaryExpression(operation: operation, left: left as! Expression, right: right as! Expression) + left = BinaryExpression(operation: operation, left: left, right: right) } return left } - func parseAdditiveExpression() throws -> Statement { + func parseAdditiveExpression() throws -> Expression { var left = try parseMultiplicativeExpression() while typeof(.additiveBinaryOperator) { let operation = tokens[current] current += 1 let right = try parseMultiplicativeExpression() - left = BinaryExpression(operation: operation, left: left as! Expression, right: right as! Expression) + left = BinaryExpression(operation: operation, left: left, right: right) } return left } - func parseComparisonExpression() throws -> Statement { + func parseComparisonExpression() throws -> Expression { var left = try parseAdditiveExpression() while typeof(.comparisonBinaryOperator) || typeof(.in) || typeof(.notIn) { let operation = tokens[current] current += 1 let right = try parseAdditiveExpression() - left = BinaryExpression(operation: operation, left: left as! Expression, right: right as! Expression) + left = BinaryExpression(operation: operation, left: left, right: right) } return left } - func parseLogicalNegationExpression() throws -> Statement { + func parseLogicalNegationExpression() throws -> Expression { var right: UnaryExpression? while typeof(.not) { let operation = tokens[current] current += 1 let argument = try parseLogicalNegationExpression() - right = UnaryExpression(operation: operation, argument: argument as! Expression) + right = UnaryExpression(operation: operation, argument: argument) } if let right { @@ -254,44 +250,52 @@ func parse(tokens: [Token]) throws -> Program { } } - func parseLogicalAndExpression() throws -> Statement { + func parseLogicalAndExpression() throws -> Expression { var left = try parseLogicalNegationExpression() while typeof(.and) { let operation = tokens[current] current += 1 let right = try parseLogicalNegationExpression() - left = BinaryExpression(operation: operation, left: left as! Expression, right: right as! Expression) + left = BinaryExpression(operation: operation, left: left, right: right) } return left } - func parseLogicalOrExpression() throws -> Statement { + func parseLogicalOrExpression() throws -> Expression { var left = try parseLogicalAndExpression() while typeof(.or) { let operation = tokens[current] current += 1 let right = try parseLogicalAndExpression() - left = BinaryExpression(operation: operation, left: left as! Expression, right: right as! Expression) + left = BinaryExpression(operation: operation, left: left, right: right) } return left } - func parseTernaryExpression() throws -> Statement { + func parseTernaryExpression() throws -> Expression { let a = try parseLogicalOrExpression() if typeof(.if) { - current += 1 - let test = try parseLogicalOrExpression() - try expect(type: .else, error: "Expected else token") - let b = try parseLogicalOrExpression() - return If(test: test as! Expression, body: [a], alternate: [b]) + // Ternary expression + current += 1 // consume if + let predicate = try parseLogicalOrExpression() + + if typeof(.else) { + // Ternary expression with else + current += 1 // consume else + let b = try parseLogicalOrExpression() + return If(test: predicate, body: [a], alternate: [b]) + } else { + // Select expression on iterable + return SelectExpression(iterable: a, test: predicate) + } } return a } - func parseExpression() throws -> Statement { + func parseExpression() throws -> Expression { try parseTernaryExpression() } @@ -314,9 +318,11 @@ func parse(tokens: [Token]) throws -> Program { if typeof(.equals) { current += 1 - let value = try parseSetStatement() + // Parse the right-hand side as an expression + let value = try parseExpression() - return Set(assignee: left as! Expression, value: value as! Expression) + // Explicitly cast 'value' to 'Expression' + return Set(assignee: left, value: value) } return left @@ -334,31 +340,37 @@ func parse(tokens: [Token]) throws -> Program { && (tokens[current + 1].type == .elseIf || tokens[current + 1].type == .else || tokens[current + 1].type == .endIf)) { - try body.append(parseAny()) + body.append(try parseAny()) } if tokens[current].type == .openStatement, tokens[current + 1].type != .endIf { current += 1 if typeof(.elseIf) { try expect(type: .elseIf, error: "Expected elseif token") - try alternate.append(parseIfStatement()) + alternate.append(try parseIfStatement()) } else { try expect(type: .else, error: "Expected else token") try expect(type: .closeStatement, error: "Expected closing statement token") while !(tokens[current].type == .openStatement && tokens[current + 1].type == .endIf) { - try alternate.append(parseAny()) + alternate.append(try parseAny()) } } } - return If(test: test as! Expression, body: body, alternate: alternate) + return If(test: test, body: body, alternate: alternate) } - func parsePrimaryExpression() throws -> Statement { + func parsePrimaryExpression() throws -> Expression { let token = tokens[current] switch token.type { case .numericLiteral: current += 1 - return NumericLiteral(value: Int(token.value) ?? 0) + if let intValue = Int(token.value) { + return NumericLiteral(value: intValue) + } else if let doubleValue = Double(token.value) { + return NumericLiteral(value: doubleValue) + } else { + throw JinjaError.parser("Invalid numeric literal: \(token.value)") + } case .stringLiteral: current += 1 return StringLiteral(value: token.value) @@ -383,7 +395,7 @@ func parse(tokens: [Token]) throws -> Program { current += 1 var values: [Expression] = [] while !typeof(.closeSquareBracket) { - try values.append(parseExpression() as! Expression) + try values.append(parseExpression()) if typeof(.comma) { current += 1 } @@ -392,12 +404,20 @@ func parse(tokens: [Token]) throws -> Program { return ArrayLiteral(value: values) case .openCurlyBracket: current += 1 - var values: [(Expression, Expression)] = [] + var values: [String: Expression] = [:] while !typeof(.closeCurlyBracket) { let key = try parseExpression() try expect(type: .colon, error: "Expected colon between key and value in object literal") let value = try parseExpression() - values.append((key as! Expression, value as! Expression)) + + if let key = key as? StringLiteral { + values[key.value] = value + } else if let key = key as? Identifier { + values[key.value] = value + } else { + throw JinjaError.syntax("Expected string literal or identifier as key in object literal") + } + if typeof(.comma) { current += 1 } @@ -409,13 +429,13 @@ func parse(tokens: [Token]) throws -> Program { } } - func parseExpressionSequence(primary: Bool = false) throws -> Statement { + func parseExpressionSequence(primary: Bool = false) throws -> Expression { let fn = primary ? parsePrimaryExpression : parseExpression - var expressions: [Expression] = try [fn() as! Expression] + var expressions: [Expression] = try [fn()] let isTuple = typeof(.comma) while isTuple { current += 1 - try expressions.append(fn() as! Expression) + try expressions.append(fn()) if !typeof(.comma) { break } @@ -437,9 +457,10 @@ func parse(tokens: [Token]) throws -> Program { func parseForStatement() throws -> Statement { let loopVariable = try parseExpressionSequence(primary: true) + // Check if the loop variable is an Identifier or TupleLiteral if !(loopVariable is Identifier || loopVariable is TupleLiteral) { throw JinjaError.syntax( - "Expected identifier/tuple for the loop variable, got \(type(of:loopVariable)) instead" + "Expected identifier or tuple for the loop variable, got \(type(of: loopVariable)) instead" ) } @@ -447,22 +468,58 @@ func parse(tokens: [Token]) throws -> Program { let iterable = try parseExpression() + // Handle optional if condition + var ifCondition: Expression? = nil + if typeof(.if) { + current += 1 // Consume 'if' token + ifCondition = try parseExpression() + } + try expect(type: .closeStatement, error: "Expected closing statement token") var body: [Statement] = [] - while not(.openStatement, .endFor) { - try body.append(parseAny()) + var defaultBlock: [Statement] = [] + + while not(.openStatement, .endFor) && not(.openStatement, .else) { + body.append(try parseAny()) } - if let loopVariable = loopVariable as? Loopvar { - return For(loopvar: loopVariable, iterable: iterable as! Expression, body: body) + if typeof(.openStatement, .else) { + current += 1 // Consume '{%' + try expect(type: .else, error: "Expected else token") + try expect(type: .closeStatement, error: "Expected closing statement token") + + while not(.openStatement, .endFor) { + try defaultBlock.append(parseAny()) + } } - throw JinjaError.syntax( - "Expected identifier/tuple for the loop variable, got \(type(of:loopVariable)) instead" + return For( + loopvar: loopVariable, + iterable: iterable, + body: body, + defaultBlock: defaultBlock, + ifCondition: ifCondition ) } + func parseMacroStatement() throws -> Macro { + let name = try parsePrimaryExpression() + if !(name is Identifier) { + throw JinjaError.syntax("Expected identifier following macro statement") + } + let args = try parseArgs() + try expect(type: .closeStatement, error: "Expected closing statement token") + + var body: [Statement] = [] + + while not(.openStatement, .endMacro) { + body.append(try parseAny()) + } + + return Macro(name: name as! Identifier, args: args, body: body) + } + func parseJinjaStatement() throws -> Statement { try expect(type: .openStatement, error: "Expected opening statement token") var result: Statement @@ -484,6 +541,12 @@ func parse(tokens: [Token]) throws -> Program { try expect(type: .openStatement, error: "Expected {% token") try expect(type: .endFor, error: "Expected endfor token") try expect(type: .closeStatement, error: "Expected %} token") + case .macro: + current += 1 + result = try parseMacroStatement() + try expect(type: .openStatement, error: "Expected {% token") + try expect(type: .endMacro, error: "Expected endmacro token") + try expect(type: .closeStatement, error: "Expected %} token") default: throw JinjaError.syntax("Unknown statement type: \(tokens[current].type)") } diff --git a/Sources/Runtime.swift b/Sources/Runtime.swift index 73a0d48..6d3b17a 100644 --- a/Sources/Runtime.swift +++ b/Sources/Runtime.swift @@ -8,9 +8,9 @@ import Foundation protocol RuntimeValue { - associatedtype T - var value: T { get set } + associatedtype ValueType + var value: ValueType { get } var builtins: [String: any RuntimeValue] { get set } func bool() -> Bool @@ -21,7 +21,12 @@ struct NumericValue: RuntimeValue { var builtins: [String: any RuntimeValue] = [:] func bool() -> Bool { - self.value as? Int != 0 + if let intValue = self.value as? Int { + return intValue != 0 + } else if let doubleValue = self.value as? Double { + return doubleValue != 0.0 + } + return false } } @@ -35,7 +40,7 @@ struct BooleanValue: RuntimeValue { } struct NullValue: RuntimeValue { - var value: (any RuntimeValue)? + let value: Any? = nil var builtins: [String: any RuntimeValue] = [:] func bool() -> Bool { @@ -44,7 +49,7 @@ struct NullValue: RuntimeValue { } struct UndefinedValue: RuntimeValue { - var value: (any RuntimeValue)? + let value: Any? = nil var builtins: [String: any RuntimeValue] = [:] func bool() -> Bool { @@ -69,51 +74,143 @@ struct ArrayValue: RuntimeValue { } struct TupleValue: RuntimeValue { - var value: ArrayValue + var value: [any RuntimeValue] var builtins: [String: any RuntimeValue] = [:] + init(value: [any RuntimeValue]) { + self.value = value + self.builtins["length"] = FunctionValue(value: { _, _ in + NumericValue(value: value.count) + }) + } + func bool() -> Bool { - self.value.bool() + !self.value.isEmpty + } +} + +struct OrderedDictionary { + private var _dictionary: [Key: Value] + private var _keys: [Key] + + var keys: [Key] { _keys } + var values: [Value] { _keys.map { _dictionary[$0]! } } + var count: Int { _keys.count } + var isEmpty: Bool { _keys.isEmpty } + + init(_ elements: [(Key, Value)] = []) { + _dictionary = [:] + _keys = [] + // Preserve exact order of elements as provided + elements.forEach { key, value in + _dictionary[key] = value + if !_keys.contains(key) { + _keys.append(key) + } + } + } + + subscript(key: Key) -> Value? { + get { _dictionary[key] } + set { + if newValue == nil { + _dictionary.removeValue(forKey: key) + _keys.removeAll { $0 == key } + } else { + if _dictionary[key] == nil { + _keys.append(key) + } + _dictionary[key] = newValue + } + } + } + + func map(_ transform: (Key, Value) throws -> T) rethrows -> [T] { + try _keys.map { key in + try transform(key, _dictionary[key]!) + } } } struct ObjectValue: RuntimeValue { - var value: [String: any RuntimeValue] - var builtins: [String: any RuntimeValue] = [:] + private var storage: OrderedDictionary + var builtins: [String: any RuntimeValue] + + var value: [String: any RuntimeValue] { Dictionary(storage.map { ($0, $1) }, uniquingKeysWith: { $1 }) } + var orderedKeys: [String] { storage.keys } + + init(value: [String: any RuntimeValue], keyOrder: [String]? = nil) { + // If keyOrder is provided, use it; otherwise, maintain the original order from the dictionary + let orderedKeys = keyOrder ?? Array(value.keys) + let orderedPairs = orderedKeys.compactMap { key in + value[key].map { (key, $0) } + } + + // Recursively create OrderedDictionary for nested objects + let processedPairs = orderedPairs.map { key, value -> (String, any RuntimeValue) in + if let objectValue = value as? ObjectValue { + // Already an ObjectValue, use it directly + return (key, objectValue) + } else if let dictValue = value.value as? [String: any RuntimeValue] { + // If the value contains a dictionary, convert it to ObjectValue + return (key, ObjectValue(value: dictValue)) + } + return (key, value) + } + + self.storage = OrderedDictionary(processedPairs) - init(value: [String: any RuntimeValue]) { - self.value = value self.builtins = [ "get": FunctionValue(value: { args, _ in - if let key = args[0] as? StringValue { - if let value = value.first(where: { $0.0 == key.value }) { - return value as! (any RuntimeValue) - } else if args.count > 1 { - return args[1] - } else { - return NullValue() - } - } else { - throw JinjaError.runtime("Object key must be a string: got \(type(of:args[0]))") + guard let key = args[0] as? StringValue else { + throw JinjaError.runtime("Object key must be a string: got \(type(of: args[0]))") + } + if let value = value[key.value] { + return value + } else if args.count > 1 { + return args[1] } + return NullValue() }), "items": FunctionValue(value: { _, _ in - var items: [ArrayValue] = [] - for (k, v) in value { - items.append( - ArrayValue(value: [ - StringValue(value: k), - v, - ]) - ) - } - return items as! (any RuntimeValue) + ArrayValue( + value: orderedPairs.map { key, value in + ArrayValue(value: [StringValue(value: key), value]) + } + ) }), ] } + mutating func setValue(key: String, value: any RuntimeValue) { + storage[key] = value + } + func bool() -> Bool { - !self.value.isEmpty + !storage.isEmpty + } +} + +extension ObjectValue { + func toJSON(indent: Int? = nil, depth: Int = 0) throws -> String { + let indentValue = indent != nil ? String(repeating: " ", count: indent!) : "" + let basePadding = indent != nil ? "\n" + String(repeating: indentValue, count: depth) : "" + let childrenPadding = indent != nil ? basePadding + indentValue : "" + + // Use orderedKeys to maintain insertion order + let pairs = try orderedKeys.map { key in + guard let value = value[key] else { + throw JinjaError.runtime("Missing value for key: \(key)") + } + let jsonValue = try Jinja.toJSON(value, indent: indent, depth: depth + 1) + return "\"\(key)\": \(jsonValue)" + } + + if indent != nil { + return "{\(childrenPadding)\(pairs.joined(separator: ",\(childrenPadding)"))\(basePadding)}" + } else { + return "{\(pairs.joined(separator: ", "))}" + } } } @@ -146,12 +243,18 @@ struct StringValue: RuntimeValue { }), "title": FunctionValue(value: { _, _ in - StringValue(value: value.capitalized) + StringValue(value: value.titleCase()) }), "length": FunctionValue(value: { _, _ in NumericValue(value: value.count) }), + "rstrip": FunctionValue(value: { _, _ in + StringValue(value: value.replacingOccurrences(of: "\\s+$", with: "", options: .regularExpression)) + }), + "lstrip": FunctionValue(value: { _, _ in + StringValue(value: value.replacingOccurrences(of: "^\\s+", with: "", options: .regularExpression)) + }), ] } @@ -177,17 +280,14 @@ struct Interpreter { let lastEvaluated = try self.evaluate(statement: statement, environment: environment) if !(lastEvaluated is NullValue), !(lastEvaluated is UndefinedValue) { - if let value = lastEvaluated.value as? String { - result += value + if let stringValue = lastEvaluated as? StringValue { + result += stringValue.value + } else if let numericValue = lastEvaluated as? NumericValue { + result += String(describing: numericValue.value) + } else if let booleanValue = lastEvaluated as? BooleanValue { + result += String(booleanValue.value) } else { - switch lastEvaluated.value { - case let value as Int: - result += String(value) - case let value as String: - result += value - default: - throw JinjaError.runtime("Unknown value type:\(type(of: lastEvaluated.value))") - } + throw JinjaError.runtime("Cannot convert to string: \(type(of: lastEvaluated))") } } } @@ -206,26 +306,30 @@ struct Interpreter { try environment.setVariable(name: variableName, value: rhs) } else if let member = node.assignee as? MemberExpression { let object = try self.evaluate(statement: member.object, environment: environment) + guard var objectValue = object as? ObjectValue else { + throw JinjaError.runtime("Cannot assign to member of non-object") + } + guard let property = member.property as? Identifier else { + throw JinjaError.runtime("Cannot assign to member with non-identifier property") + } - if var object = object as? ObjectValue { - if let property = member.property as? Identifier { - object.value[property.value] = rhs - } else { - throw JinjaError.runtime("Cannot assign to member with non-identifier property") - } + // Modify the copy + objectValue.setValue(key: property.value, value: rhs) + + // Update the environment with the modified copy + if let parentIdentifier = member.object as? Identifier { + try environment.setVariable(name: parentIdentifier.value, value: objectValue) } else { - throw JinjaError.runtime("Cannot assign to member of non-object") + throw JinjaError.runtime("Cannot assign to computed member expression") } } else { - throw JinjaError.runtime("Invalid assignee type: \(type(of: node.assignee))") + throw JinjaError.runtime("Invalid LHS inside assignment expression: \(node.assignee)") } - return NullValue() } func evaluateIf(node: If, environment: Environment) throws -> StringValue { let test = try self.evaluate(statement: node.test, environment: environment) - return try self.evaluateBlock(statements: test.bool() ? node.body : node.alternate, environment: environment) } @@ -233,66 +337,99 @@ struct Interpreter { environment.lookupVariable(name: node.value) } - func evaluateFor(node: For, environment: Environment) throws -> any RuntimeValue { + func evaluateFor(node: For, environment: Environment) throws -> StringValue { + // Scope for the for loop let scope = Environment(parent: environment) - let iterable = try self.evaluate(statement: node.iterable, environment: scope) - var result = "" - if let iterable = iterable as? ArrayValue { - for i in 0 ..< iterable.value.count { - let loop: [String: any RuntimeValue] = [ - "index": NumericValue(value: i + 1), - "index0": NumericValue(value: i), - "revindex": NumericValue(value: iterable.value.count - i), - "revindex0": NumericValue(value: iterable.value.count - i - 1), - "first": BooleanValue(value: i == 0), - "last": BooleanValue(value: i == iterable.value.count - 1), - "length": NumericValue(value: iterable.value.count), - "previtem": i > 0 ? iterable.value[i - 1] : UndefinedValue(), - "nextitem": i < iterable.value.count - 1 ? iterable.value[i + 1] : UndefinedValue(), - ] - - try scope.setVariable(name: "loop", value: ObjectValue(value: loop)) - - let current = iterable.value[i] - - if let identifier = node.loopvar as? Identifier { + let test: Expression? + let iterable: any RuntimeValue + if let selectExpression = node.iterable as? SelectExpression { + iterable = try self.evaluate(statement: selectExpression.iterable, environment: scope) + test = selectExpression.test + } else { + iterable = try self.evaluate(statement: node.iterable, environment: scope) + test = nil + } + + guard let arrayIterable = iterable as? ArrayValue else { + throw JinjaError.runtime("Expected iterable type in for loop: got \(type(of: iterable))") + } + + var items: [any RuntimeValue] = [] + var scopeUpdateFunctions: [(Environment) throws -> Void] = [] + + for current in arrayIterable.value { + let loopScope = Environment(parent: scope) + + var scopeUpdateFunction: (Environment) throws -> Void + if let identifier = node.loopvar as? Identifier { + scopeUpdateFunction = { scope in try scope.setVariable(name: identifier.value, value: current) - } else { + } + } else if let tupleLiteral = node.loopvar as? TupleLiteral { + guard let currentArray = current as? ArrayValue else { + throw JinjaError.runtime("Cannot unpack non-iterable type: \(type(of: current))") } - switch node.loopvar { - case let identifier as Identifier: - try scope.setVariable(name: identifier.value, value: current) - case let tupleLiteral as TupleLiteral: - if let current = current as? ArrayValue { - if tupleLiteral.value.count != current.value.count { - throw JinjaError.runtime( - "Too \(tupleLiteral.value.count > current.value.count ? "few" : "many") items to unpack" - ) - } + if tupleLiteral.value.count != currentArray.value.count { + throw JinjaError.runtime( + "Too \(tupleLiteral.value.count > currentArray.value.count ? "few" : "many") items to unpack" + ) + } - for j in 0 ..< tupleLiteral.value.count { - if let identifier = tupleLiteral.value[j] as? Identifier { - try scope.setVariable(name: identifier.value, value: current.value[j]) - } else { - throw JinjaError.runtime( - "Cannot unpack non-identifier type: \(type(of:tupleLiteral.value[j]))" - ) - } + scopeUpdateFunction = { scope in + for (i, value) in tupleLiteral.value.enumerated() { + guard let identifier = value as? Identifier else { + throw JinjaError.runtime("Cannot unpack non-identifier type: \(type(of: value))") } - } else { - throw JinjaError.runtime("Cannot unpack non-iterable type: \(type(of:current))") + try scope.setVariable(name: identifier.value, value: currentArray.value[i]) } - default: - throw JinjaError.syntaxNotSupported(String(describing: node.loopvar)) } + } else { + throw JinjaError.runtime("Invalid loop variable(s): \(type(of: node.loopvar))") + } - let evaluated = try self.evaluateBlock(statements: node.body, environment: scope) - result += evaluated.value + if let test = test { + try scopeUpdateFunction(loopScope) + let testValue = try self.evaluate(statement: test, environment: loopScope) + if !testValue.bool() { + continue + } } - } else { - throw JinjaError.runtime("Expected iterable type in for loop: got \(type(of:iterable))") + + items.append(current) + scopeUpdateFunctions.append(scopeUpdateFunction) + } + + var result = "" + var noIteration = true + + for i in 0 ..< items.count { + let loop: [String: any RuntimeValue] = [ + "index": NumericValue(value: i + 1), + "index0": NumericValue(value: i), + "revindex": NumericValue(value: items.count - i), + "revindex0": NumericValue(value: items.count - i - 1), + "first": BooleanValue(value: i == 0), + "last": BooleanValue(value: i == items.count - 1), + "length": NumericValue(value: items.count), + "previtem": i > 0 ? items[i - 1] : UndefinedValue(), + "nextitem": i < items.count - 1 ? items[i + 1] : UndefinedValue(), + ] + + try scope.setVariable(name: "loop", value: ObjectValue(value: loop)) + + try scopeUpdateFunctions[i](scope) + + let evaluated = try self.evaluateBlock(statements: node.body, environment: scope) + result += evaluated.value + + noIteration = false + } + + if noIteration { + let defaultEvaluated = try self.evaluateBlock(statements: node.defaultBlock, environment: scope) + result += defaultEvaluated.value } return StringValue(value: result) @@ -302,31 +439,72 @@ struct Interpreter { let left = try self.evaluate(statement: node.left, environment: environment) if node.operation.value == "and" { - return left.bool() ? try self.evaluate(statement: node.right, environment: environment) : left + if !left.bool() { + return left + } + let right = try self.evaluate(statement: node.right, environment: environment) + return right } else if node.operation.value == "or" { return left.bool() ? left : try self.evaluate(statement: node.right, environment: environment) } let right = try self.evaluate(statement: node.right, environment: environment) + // == if node.operation.value == "==" { - switch left.value { - case let value as String: - return BooleanValue(value: value == right.value as! String) - case let value as Int: - return BooleanValue(value: value == right.value as! Int) - case let value as Bool: - return BooleanValue(value: value == right.value as! Bool) - default: - throw JinjaError.runtime( - "Unknown left value type:\(type(of: left.value)), right value type:\(type(of: right.value))" - ) + if let left = left as? StringValue, let right = right as? StringValue { + return BooleanValue(value: left.value == right.value) + } else if let left = left as? NumericValue, let right = right as? NumericValue { + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt == rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble == rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) == rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble == Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for equality comparison") + } + } else if let left = left as? BooleanValue, let right = right as? BooleanValue { + return BooleanValue(value: left.value == right.value) + } else if left is NullValue, right is NullValue { + return BooleanValue(value: true) + } else if left is UndefinedValue, right is UndefinedValue { + return BooleanValue(value: true) + } else if type(of: left) == type(of: right) { + return BooleanValue(value: false) + } else { + return BooleanValue(value: false) } - } else if node.operation.value == "!=" { - if type(of: left) != type(of: right) { + } + + // != + if node.operation.value == "!=" { + if let left = left as? StringValue, let right = right as? StringValue { + return BooleanValue(value: left.value != right.value) + } else if let left = left as? NumericValue, let right = right as? NumericValue { + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt != rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble != rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) != rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble != Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for inequality comparison") + } + } else if let left = left as? BooleanValue, let right = right as? BooleanValue { + return BooleanValue(value: left.value != right.value) + } else if left is NullValue, right is NullValue { + return BooleanValue(value: false) + } else if left is UndefinedValue, right is UndefinedValue { + return BooleanValue(value: false) + } else if type(of: left) == type(of: right) { return BooleanValue(value: true) } else { - return BooleanValue(value: left.value as! AnyHashable != right.value as! AnyHashable) + return BooleanValue(value: true) } } @@ -336,92 +514,230 @@ struct Interpreter { throw JinjaError.runtime("Cannot perform operation on null values") } else if let left = left as? NumericValue, let right = right as? NumericValue { switch node.operation.value { - case "+": throw JinjaError.syntaxNotSupported("+") - case "-": throw JinjaError.syntaxNotSupported("-") - case "*": throw JinjaError.syntaxNotSupported("*") - case "/": throw JinjaError.syntaxNotSupported("/") + case "+": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt + rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return NumericValue(value: leftDouble + rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return NumericValue(value: Double(leftInt) + rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return NumericValue(value: leftDouble + Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for addition") + } + case "-": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt - rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return NumericValue(value: leftDouble - rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return NumericValue(value: Double(leftInt) - rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return NumericValue(value: leftDouble - Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for subtraction") + } + case "*": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt * rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return NumericValue(value: leftDouble * rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return NumericValue(value: Double(leftInt) * rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return NumericValue(value: leftDouble * Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for multiplication") + } + case "/": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt / rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return NumericValue(value: leftDouble / rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return NumericValue(value: Double(leftInt) / rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return NumericValue(value: leftDouble / Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for division") + } case "%": - switch left.value { - case is Int: - return NumericValue(value: left.value as! Int % (right.value as! Int)) - default: - throw JinjaError.runtime("Unknown value type:\(type(of: left.value))") + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return NumericValue(value: leftInt % rightInt) + } else { + throw JinjaError.runtime("Unsupported numeric types for modulus") + } + case "<": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt < rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble < rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) < rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble < Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for less than comparison") + } + case ">": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt > rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble > rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) > rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble > Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for greater than comparison") + } + case ">=": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt >= rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble >= rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) >= rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble >= Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for greater than or equal to comparison") + } + case "<=": + if let leftInt = left.value as? Int, let rightInt = right.value as? Int { + return BooleanValue(value: leftInt <= rightInt) + } else if let leftDouble = left.value as? Double, let rightDouble = right.value as? Double { + return BooleanValue(value: leftDouble <= rightDouble) + } else if let leftInt = left.value as? Int, let rightDouble = right.value as? Double { + return BooleanValue(value: Double(leftInt) <= rightDouble) + } else if let leftDouble = left.value as? Double, let rightInt = right.value as? Int { + return BooleanValue(value: leftDouble <= Double(rightInt)) + } else { + throw JinjaError.runtime("Unsupported numeric types for less than or equal to comparison") } - case "<": throw JinjaError.syntaxNotSupported("<") - case ">": throw JinjaError.syntaxNotSupported(">") - case ">=": throw JinjaError.syntaxNotSupported(">=") - case "<=": throw JinjaError.syntaxNotSupported("<=") default: throw JinjaError.runtime("Unknown operation type:\(node.operation.value)") } - } else if left is ArrayValue && right is ArrayValue { + } else if let left = left as? ArrayValue, let right = right as? ArrayValue { switch node.operation.value { - case "+": break + case "+": + return ArrayValue(value: left.value + right.value) default: throw JinjaError.runtime("Unknown operation type:\(node.operation.value)") } - } else if right is ArrayValue { - throw JinjaError.syntaxNotSupported("right is ArrayValue") - } - - if left is StringValue || right is StringValue { - switch node.operation.value { - case "+": - var rightValue = "" - var leftValue = "" - switch right.value { - case let value as String: - rightValue = value - case let value as Int: - rightValue = String(value) - case let value as Bool: - rightValue = String(value) - default: - throw JinjaError.runtime("Unknown right value type:\(type(of: right.value))") + } else if let right = right as? ArrayValue { + let member: Bool + if let left = left as? StringValue { + member = right.value.contains { + if let item = $0 as? StringValue { + return item.value == left.value + } + return false } - - switch left.value { - case let value as String: - leftValue = value - case let value as Int: - leftValue = String(value) - case let value as Bool: - rightValue = String(value) - default: - throw JinjaError.runtime("Unknown left value type:\(type(of: left.value))") + } else if let left = left as? NumericValue { + member = right.value.contains { + if let item = $0 as? NumericValue { + return item.value as! Int == left.value as! Int + } + return false } - - return StringValue(value: leftValue + rightValue) - default: - break + } else if let left = left as? BooleanValue { + member = right.value.contains { + if let item = $0 as? BooleanValue { + return item.value == left.value + } + return false + } + } else { + throw JinjaError.runtime("Unsupported left type for 'in'/'not in' operation with ArrayValue") } - } - if let left = left as? StringValue, let right = right as? StringValue { switch node.operation.value { case "in": - return BooleanValue(value: right.value.contains(left.value)) + return BooleanValue(value: member) case "not in": - return BooleanValue(value: !right.value.contains(left.value)) + return BooleanValue(value: !member) default: throw JinjaError.runtime("Unknown operation type:\(node.operation.value)") } } - if left is StringValue, right is ObjectValue { + if let left = left as? StringValue { switch node.operation.value { + case "+": + let rightValue: String + if let rightString = right as? StringValue { + rightValue = rightString.value + } else if let rightNumeric = right as? NumericValue { + rightValue = String(describing: rightNumeric.value) + } else if let rightBoolean = right as? BooleanValue { + rightValue = String(rightBoolean.value) + } else if right is UndefinedValue { + rightValue = "" + } else { + throw JinjaError.runtime("Unsupported right operand type for string concatenation") + } + return StringValue(value: left.value + rightValue) case "in": - if let leftString = (left as? StringValue)?.value, - let rightObject = right as? ObjectValue - { - return BooleanValue(value: rightObject.value.keys.contains(leftString)) + if let right = right as? StringValue { + return BooleanValue(value: right.value.contains(left.value)) + } else if let right = right as? ObjectValue { + return BooleanValue(value: right.value.keys.contains(left.value)) + } else if let right = right as? ArrayValue { + return BooleanValue( + value: right.value.contains { + if let item = $0 as? StringValue { + return item.value == left.value + } + return false + } + ) + } else { + throw JinjaError.runtime("Right operand of 'in' must be a StringValue, ArrayValue, or ObjectValue") } case "not in": - if let leftString = (left as? StringValue)?.value, - let rightObject = right as? ObjectValue - { - return BooleanValue(value: !rightObject.value.keys.contains(leftString)) + if let right = right as? StringValue { + return BooleanValue(value: !right.value.contains(left.value)) + } else if let right = right as? ObjectValue { + return BooleanValue(value: !right.value.keys.contains(left.value)) + } else if let right = right as? ArrayValue { + return BooleanValue( + value: !right.value.contains { + if let item = $0 as? StringValue { + return item.value == left.value + } + return false + } + ) + } else { + throw JinjaError.runtime( + "Right operand of 'not in' must be a StringValue, ArrayValue, or ObjectValue" + ) } + default: + break + } + } else if let right = right as? StringValue { + if node.operation.value == "+" { + if let leftString = left as? StringValue { + return StringValue(value: leftString.value + right.value) + } else if let leftNumeric = left as? NumericValue { + return StringValue(value: String(describing: leftNumeric.value) + right.value) + } else if let leftBoolean = left as? BooleanValue { + return StringValue(value: String(leftBoolean.value) + right.value) + } else { + throw JinjaError.runtime("Unsupported left operand type for string concatenation") + } + } + } + + if let left = left as? StringValue, let right = right as? ObjectValue { + switch node.operation.value { + case "in": + return BooleanValue(value: right.value.keys.contains(left.value)) + case "not in": + return BooleanValue(value: !right.value.keys.contains(left.value)) default: throw JinjaError.runtime( "Unsupported operation '\(node.operation.value)' between StringValue and ObjectValue" @@ -463,19 +779,19 @@ struct Interpreter { return ArrayValue( value: slice( object.value, - start: start.value as? Int, - stop: stop.value as? Int, - step: step.value as? Int + start: (start as? NumericValue)?.value as? Int, + stop: (stop as? NumericValue)?.value as? Int, + step: (step as? NumericValue)?.value as? Int ) ) } else if let object = object as? StringValue { return StringValue( value: slice( - Array(arrayLiteral: object.value), - start: start.value as? Int, - stop: stop.value as? Int, - step: step.value as? Int - ).joined() + Array(object.value), + start: (start as? NumericValue)?.value as? Int, + stop: (stop as? NumericValue)?.value as? Int, + step: (step as? NumericValue)?.value as? Int + ).map { String($0) }.joined() ) } @@ -484,7 +800,6 @@ struct Interpreter { func evaluateMemberExpression(expr: MemberExpression, environment: Environment) throws -> any RuntimeValue { let object = try self.evaluate(statement: expr.object, environment: environment) - var property: any RuntimeValue if expr.computed { if let property = expr.property as? SliceExpression { @@ -495,7 +810,6 @@ struct Interpreter { } else { property = StringValue(value: (expr.property as! Identifier).value) } - var value: (any RuntimeValue)? if let object = object as? ObjectValue { if let property = property as? StringValue { @@ -503,34 +817,55 @@ struct Interpreter { } else { throw JinjaError.runtime("Cannot access property with non-string: got \(type(of:property))") } - } else if object is ArrayValue || object is StringValue { + } else if let object = object as? ArrayValue { if let property = property as? NumericValue { - if let object = object as? ArrayValue { - let index = property.value as! Int - if index >= 0 { + if let index = property.value as? Int { + if index >= 0 && index < object.value.count { value = object.value[index] - } else { + } else if index < 0 && index >= -object.value.count { value = object.value[object.value.count + index] + } else { + value = UndefinedValue() } - } else if let object = object as? StringValue { - let index = object.value.index(object.value.startIndex, offsetBy: property.value as! Int) - value = StringValue(value: String(object.value[index])) + } else { + throw JinjaError.runtime("Array index must be an integer") } } else if let property = property as? StringValue { value = object.builtins[property.value] } else { throw JinjaError.runtime( - "Cannot access property with non-string/non-number: got \(type(of:property))" + "Cannot access property with non-string/non-number: got \(type(of: property))" + ) + } + } else if let object = object as? StringValue { + if let property = property as? NumericValue { + if let index = property.value as? Int { + if index >= 0 && index < object.value.count { + let strIndex = object.value.index(object.value.startIndex, offsetBy: index) + value = StringValue(value: String(object.value[strIndex])) + } else if index < 0 && index >= -object.value.count { + let strIndex = object.value.index(object.value.startIndex, offsetBy: object.value.count + index) + value = StringValue(value: String(object.value[strIndex])) + } else { + value = UndefinedValue() + } + } else { + throw JinjaError.runtime("String index must be an integer") + } + } else if let property = property as? StringValue { + value = object.builtins[property.value] + } else { + throw JinjaError.runtime( + "Cannot access property with non-string/non-number: got \(type(of: property))" ) } } else { if let property = property as? StringValue { - value = object.builtins[property.value]! + value = object.builtins[property.value] } else { throw JinjaError.runtime("Cannot access property with non-string: got \(type(of:property))") } } - if let value { return value } else { @@ -561,7 +896,7 @@ struct Interpreter { } } - if kwargs.count > 0 { + if !kwargs.isEmpty { args.append(ObjectValue(value: kwargs)) } @@ -575,9 +910,11 @@ struct Interpreter { } func evaluateFilterExpression(node: FilterExpression, environment: Environment) throws -> any RuntimeValue { - let operand = try evaluate(statement: node.operand, environment: environment) - + let operand = try self.evaluate(statement: node.operand, environment: environment) if let identifier = node.filter as? Identifier { + if identifier.value == "tojson" { + return try StringValue(value: toJSON(operand)) + } if let arrayValue = operand as? ArrayValue { switch identifier.value { case "list": @@ -591,7 +928,32 @@ struct Interpreter { case "reverse": return ArrayValue(value: arrayValue.value.reversed()) case "sort": - throw JinjaError.todo("TODO: ArrayValue filter sort") + return ArrayValue( + value: try arrayValue.value.sorted { + // No need to cast to AnyComparable here + if let a = $0 as? NumericValue, let b = $1 as? NumericValue { + if let aInt = a.value as? Int, let bInt = b.value as? Int { + return aInt < bInt + } else if let aDouble = a.value as? Double, let bDouble = b.value as? Double { + return aDouble < bDouble + } else if let aInt = a.value as? Int, let bDouble = b.value as? Double { + return Double(aInt) < bDouble + } else if let aDouble = a.value as? Double, let bInt = b.value as? Int { + return aDouble < Double(bInt) + } else { + throw JinjaError.runtime("Unsupported numeric types for comparison") + } + } else if let a = $0 as? StringValue, let b = $1 as? StringValue { + return a.value < b.value + } else { + throw JinjaError.runtime( + "Cannot compare values of different types or non-comparable types" + ) + } + } + ) + case "map": + throw JinjaError.todo("TODO: ArrayValue filter map") default: throw JinjaError.runtime("Unknown ArrayValue filter: \(identifier.value)") } @@ -604,34 +966,38 @@ struct Interpreter { case "lower": return StringValue(value: stringValue.value.lowercased()) case "title": - return StringValue(value: stringValue.value.capitalized) + return StringValue(value: stringValue.value.titleCase()) case "capitalize": - return StringValue(value: stringValue.value.capitalized) + return StringValue(value: stringValue.value.prefix(1).uppercased() + stringValue.value.dropFirst()) case "trim": return StringValue(value: stringValue.value.trimmingCharacters(in: .whitespacesAndNewlines)) + case "indent": + return StringValue(value: stringValue.value.indent(4)) + case "string": + return stringValue default: throw JinjaError.runtime("Unknown StringValue filter: \(identifier.value)") } } else if let numericValue = operand as? NumericValue { switch identifier.value { case "abs": - return NumericValue(value: abs(numericValue.value as! Int32)) + if let intValue = numericValue.value as? Int { + return NumericValue(value: abs(intValue)) + } else if let doubleValue = numericValue.value as? Double { + return NumericValue(value: abs(doubleValue)) + } else { + throw JinjaError.runtime("Unsupported numeric type for abs filter") + } default: throw JinjaError.runtime("Unknown NumericValue filter: \(identifier.value)") } } else if let objectValue = operand as? ObjectValue { switch identifier.value { case "items": - var items: [ArrayValue] = [] - for (k, v) in objectValue.value { - items.append( - ArrayValue(value: [ - StringValue(value: k), - v, - ]) - ) + let items: [ArrayValue] = objectValue.value.map { (key, value) in + return ArrayValue(value: [StringValue(value: key), value]) } - return items as! (any RuntimeValue) + return ArrayValue(value: items) case "length": return NumericValue(value: objectValue.value.count) default: @@ -639,9 +1005,132 @@ struct Interpreter { } } - throw JinjaError.runtime("Cannot apply filter \(operand.value) to type: \(type(of:operand))") - } + throw JinjaError.runtime("Cannot apply filter \(identifier.value) to type: \(type(of: operand))") + } else if let callExpression = node.filter as? CallExpression { + if let identifier = callExpression.callee as? Identifier { + let filterName = identifier.value + + if filterName == "tojson" { + let args = try self.evaluateArguments(args: callExpression.args, environment: environment) + let indent = args.1["indent"] ?? NullValue() + + if let indentNumeric = indent as? NumericValue { + if let indentInt = indentNumeric.value as? Int { + return try StringValue(value: toJSON(operand, indent: indentInt)) + } else if let indentDouble = indentNumeric.value as? Double { + return try StringValue(value: toJSON(operand, indent: Int(indentDouble))) + } else { + throw JinjaError.runtime("If set, indent must be a number") + } + } else if indent is NullValue { + return try StringValue(value: toJSON(operand)) + } else { + throw JinjaError.runtime("If set, indent must be a number") + } + } + if let arrayValue = operand as? ArrayValue { + switch filterName { + case "selectattr", "rejectattr": + let select = filterName == "selectattr" + if arrayValue.value.contains(where: { !($0 is ObjectValue) }) { + throw JinjaError.runtime("`\(filterName)` can only be applied to array of objects") + } + if callExpression.args.contains(where: { !($0 is StringLiteral) }) { + throw JinjaError.runtime("arguments of `\(filterName)` must be strings") + } + let args = try callExpression.args.map { arg -> StringValue in + let evaluatedArg = try self.evaluate(statement: arg, environment: environment) + guard let stringValue = evaluatedArg as? StringValue else { + throw JinjaError.runtime("Arguments of `\(filterName)` must be strings") + } + return stringValue + } + let attr = args[0] + let testName = args.count > 1 ? args[1] : nil + let value = args.count > 2 ? args[2] : nil + var testFunction: ((any RuntimeValue, StringValue?) throws -> Bool) + if let testName = testName { + guard let test = environment.tests[testName.value] else { + throw JinjaError.runtime("Unknown test: \(testName.value)") + } + testFunction = { a, b in + try test(a, b ?? UndefinedValue()) + } + } else { + testFunction = { a, _ in + a.bool() + } + } + let filtered = (arrayValue.value as! [ObjectValue]).filter { item in + let a = item.value[attr.value] + let result = a != nil ? try! testFunction(a!, value) : false + return select ? result : !result + } + return ArrayValue(value: filtered) + case "map": + let evaluatedArgs = try self.evaluateArguments( + args: callExpression.args, + environment: environment + ) + let kwargs = evaluatedArgs.1 + if let attribute = kwargs["attribute"] { + let defaultValue = kwargs["default"] + let mapped = try arrayValue.value.map { item -> Any in + guard let objectValue = item as? ObjectValue else { + throw JinjaError.runtime("Items in map must be objects") + } + if let attributeString = attribute as? StringValue { + let result = + objectValue.value[attributeString.value] ?? defaultValue ?? UndefinedValue() + return result + } else { + throw JinjaError.runtime("`map` filter attribute must be a string") + } + } + return ArrayValue(value: mapped.map { $0 as! (any RuntimeValue) }) + } else { + // TODO: Implement map filter without attribute argument + // This will likely involve applying a filter function to each element. + throw JinjaError.runtime("`map` filter without `attribute` is not yet supported.") + } + default: + throw JinjaError.runtime("Unknown ArrayValue filter: \(filterName)") + } + } else if let stringValue = operand as? StringValue { + switch filterName { + case "indent": + let args = try self.evaluateArguments(args: callExpression.args, environment: environment) + let positionalArgs = args.0 + let kwargs = args.1 + let width = positionalArgs.first ?? kwargs["width"] ?? NumericValue(value: 4) + if !(width is NumericValue) { + throw JinjaError.runtime("width must be a number") + } + let first = + positionalArgs.count > 1 ? positionalArgs[1] : kwargs["first"] ?? BooleanValue(value: false) + let blank = + positionalArgs.count > 2 ? positionalArgs[2] : kwargs["blank"] ?? BooleanValue(value: false) + guard let widthInt = (width as? NumericValue)?.value as? Int else { + throw JinjaError.runtime("width must be an integer") + } + return StringValue( + value: stringValue.value.indent( + widthInt, + first: first.bool(), + blank: blank.bool() + ) + ) + default: + throw JinjaError.runtime("Unknown StringValue filter: \(filterName)") + } + } else { + throw JinjaError.runtime("Cannot apply filter '\(filterName)' to type: \(type(of: operand))") + } + } else { + throw JinjaError.runtime("Unknown filter: \(callExpression.callee)") + } + } throw JinjaError.runtime("Unknown filter: \(node.filter)") } @@ -656,6 +1145,76 @@ struct Interpreter { } } + func evaluateMacro(node: Macro, environment: Environment) throws -> NullValue { + try environment.setVariable( + name: node.name.value, + value: FunctionValue(value: { args, scope in + let macroScope = Environment(parent: scope) + + var args = args + var kwargs: [String: any RuntimeValue] = [:] + + if let lastArg = args.last, let keywordArgsValue = lastArg as? KeywordArgumentsValue { + kwargs = keywordArgsValue.value + args.removeLast() + } + + for i in 0 ..< node.args.count { + let nodeArg = node.args[i] + let passedArg = args.count > i ? args[i] : nil + + if let identifier = nodeArg as? Identifier { + if passedArg == nil { + if let defaultValue = kwargs[identifier.value] { + try macroScope.setVariable(name: identifier.value, value: defaultValue) + } else { + throw JinjaError.runtime("Missing argument: \(identifier.value)") + } + } else { + try macroScope.setVariable(name: identifier.value, value: passedArg!) + } + } else if let kwarg = nodeArg as? KeywordArgumentExpression { + let value = + try kwargs[kwarg.key.value] + ?? (passedArg ?? (try self.evaluate(statement: kwarg.value, environment: macroScope))) + + try macroScope.setVariable(name: kwarg.key.value, value: value) + } else { + throw JinjaError.runtime("Unknown argument type: \(type(of: nodeArg))") + } + } + + return try self.evaluateBlock(statements: node.body, environment: macroScope) + }) + ) + + return NullValue() + } + + func evaluateArguments( + args: [Expression], + environment: Environment + ) throws -> ([any RuntimeValue], [String: any RuntimeValue]) { + var positionalArguments: [any RuntimeValue] = [] + var keywordArguments: [String: any RuntimeValue] = [:] + + for argument in args { + if let keywordArgument = argument as? KeywordArgumentExpression { + keywordArguments[keywordArgument.key.value] = try self.evaluate( + statement: keywordArgument.value, + environment: environment + ) + } else { + if !keywordArguments.isEmpty { + throw JinjaError.runtime("Positional arguments must come before keyword arguments") + } + positionalArguments.append(try self.evaluate(statement: argument, environment: environment)) + } + } + + return (positionalArguments, keywordArguments) + } + func evaluate(statement: Statement?, environment: Environment) throws -> any RuntimeValue { if let statement { switch statement { @@ -678,7 +1237,13 @@ struct Interpreter { case let statement as UnaryExpression: return try self.evaluateUnaryExpression(node: statement, environment: environment) case let statement as NumericLiteral: - return NumericValue(value: statement.value) + if let intValue = statement.value as? Int { + return NumericValue(value: intValue) + } else if let doubleValue = statement.value as? Double { + return NumericValue(value: doubleValue) + } else { + throw JinjaError.runtime("Invalid numeric literal value") + } case let statement as CallExpression: return try self.evaluateCallExpression(expr: statement, environment: environment) case let statement as BoolLiteral: @@ -687,6 +1252,22 @@ struct Interpreter { return try self.evaluateFilterExpression(node: statement, environment: environment) case let statement as TestExpression: return try self.evaluateTestExpression(node: statement, environment: environment) + case let statement as ArrayLiteral: + return ArrayValue( + value: try statement.value.map { try self.evaluate(statement: $0, environment: environment) } + ) + case let statement as TupleLiteral: + return TupleValue( + value: try statement.value.map { try self.evaluate(statement: $0, environment: environment) } + ) + case let statement as ObjectLiteral: + var mapping: [String: any RuntimeValue] = [:] + for (key, value) in statement.value { + mapping[key] = try self.evaluate(statement: value, environment: environment) + } + return ObjectValue(value: mapping) + case let statement as Macro: + return try self.evaluateMacro(node: statement, environment: environment) case is NullLiteral: return NullValue() default: diff --git a/Sources/Utilities.swift b/Sources/Utilities.swift index c01870b..91a0012 100644 --- a/Sources/Utilities.swift +++ b/Sources/Utilities.swift @@ -38,3 +38,99 @@ func slice(_ array: [T], start: Int? = nil, stop: Int? = nil, step: Int? = 1) return slicedArray } + +func toJSON(_ input: any RuntimeValue, indent: Int? = nil, depth: Int = 0) throws -> String { + let currentDepth = depth + + switch input { + case is NullValue, is UndefinedValue: + return "null" + + case let value as NumericValue: + return String(describing: value.value) + + case let value as StringValue: + // Properly escape special characters for JSON strings + let escapedValue = value.value + .replacingOccurrences(of: "\\", with: "\\\\") + .replacingOccurrences(of: "\"", with: "\\\"") + .replacingOccurrences(of: "\n", with: "\\n") + .replacingOccurrences(of: "\r", with: "\\r") + .replacingOccurrences(of: "\t", with: "\\t") + return "\"\(escapedValue)\"" + + case let value as BooleanValue: + return value.value ? "true" : "false" + + case let arr as ArrayValue: + let indentValue = indent != nil ? String(repeating: " ", count: indent!) : "" + let basePadding = indent != nil ? "\n" + String(repeating: indentValue, count: currentDepth) : "" + let childrenPadding = indent != nil ? basePadding + indentValue : "" + + let core = try arr.value.map { try toJSON($0, indent: indent, depth: currentDepth + 1) } + + if indent != nil { + return "[\(childrenPadding)\(core.joined(separator: ",\(childrenPadding)"))\(basePadding)]" + } else { + return "[\(core.joined(separator: ", "))]" + } + + case let obj as ObjectValue: + let indentValue = indent != nil ? String(repeating: " ", count: indent!) : "" + let basePadding = indent != nil ? "\n" + String(repeating: indentValue, count: currentDepth) : "" + let childrenPadding = indent != nil ? basePadding + indentValue : "" + + // Use orderedKeys to maintain insertion order + let pairs = try obj.orderedKeys.map { key in + guard let value = obj.value[key] else { + throw JinjaError.runtime("Missing value for key: \(key)") + } + let jsonValue = try toJSON(value, indent: indent, depth: depth + 1) + return "\"\(key)\": \(jsonValue)" + } + + if indent != nil { + return "{\(childrenPadding)\(pairs.joined(separator: ",\(childrenPadding)"))\(basePadding)}" + } else { + return "{\(pairs.joined(separator: ", "))}" + } + default: + throw JinjaError.runtime("Cannot convert to JSON: \(type(of: input))") + } +} + +// Helper function to convert values to JSON strings +func jsonString(_ value: Any) throws -> String { + let data = try JSONSerialization.data(withJSONObject: value) + guard let string = String(data: data, encoding: .utf8) else { + throw JinjaError.runtime("Failed to convert value to JSON string") + } + return string +} + +extension String { + func titleCase() -> String { + return self.components(separatedBy: .whitespacesAndNewlines) + .map { word in + guard let firstChar = word.first else { return "" } + return String(firstChar).uppercased() + word.dropFirst() + } + .joined(separator: " ") + } + + func indent(_ width: Int, first: Bool = false, blank: Bool = false) -> String { + let indentString = String(repeating: " ", count: width) + return self.components(separatedBy: .newlines) + .enumerated() + .map { (index, line) in + if line.isEmpty && !blank { + return line + } + if index == 0 && !first { + return line + } + return indentString + line + } + .joined(separator: "\n") + } +} diff --git a/Tests/ChatTemplateTests.swift b/Tests/ChatTemplateTests.swift index 4b9ab6b..ef0490c 100644 --- a/Tests/ChatTemplateTests.swift +++ b/Tests/ChatTemplateTests.swift @@ -9,7 +9,12 @@ import XCTest @testable import Jinja -let messages: [[String: String]] = [ +let llama3_2visionChatTemplate = + "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- Find out if there are any images #}\n{% set image_ns = namespace(has_images=false) %} \n{%- for message in messages %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {%- set image_ns.has_images = true %}\n {%- endif %}\n {%- endfor %}\n{%- endfor %}\n\n{#- Error out if there are images and system message #}\n{%- if image_ns.has_images and not system_message == \"\" %}\n {{- raise_exception(\"Prompting with images is incompatible with system messages.\") }}\n{%- endif %}\n\n{#- System message if there are no images #}\n{%- if not image_ns.has_images %}\n {{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n {%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n {%- endif %}\n {{- \"Cutting Knowledge Date: December 2023\\n\" }}\n {{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n {%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {%- endif %}\n {{- system_message }}\n {{- \"<|eot_id|>\" }}\n{%- endif %}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' }}\n {%- if message['content'] is string %}\n {{- message['content'] }}\n {%- else %}\n {%- for content in message['content'] %}\n {%- if content['type'] == 'image' %}\n {{- '<|image|>' }}\n {%- elif content['type'] == 'text' %}\n {{- content['text'] }}\n {%- endif %}\n {%- endfor %}\n {%- endif %}\n {{- '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n" +let qwen2VLChatTemplate = + "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n{% endif %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" + +let exampleChatMessages: [[String: String]] = [ [ "role": "user", "content": "Hello, how are you?", @@ -24,27 +29,29 @@ let messages: [[String: String]] = [ ], ] -let messagesWithSystem: [[String: String]] = +let exampleChatMessagesWithSystemPrompt: [[String: String]] = [ [ "role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate", ] - ] + messages + ] + exampleChatMessages final class ChatTemplateTests: XCTestCase { struct Test { + let name: String let chatTemplate: String let data: [String: Any] let target: String } - let defaultTemplates: [Test] = [ + let defaultTemplateTests: [Test] = [ Test( + name: "Generic chat template with messages", chatTemplate: "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", data: [ - "messages": messages, + "messages": exampleChatMessages, "add_generation_prompt": false, ], target: @@ -52,10 +59,11 @@ final class ChatTemplateTests: XCTestCase { ), // facebook/blenderbot-400M-distill Test( + name: "facebook/blenderbot-400M-distill", chatTemplate: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}", data: [ - "messages": messages, + "messages": exampleChatMessages, "eos_token": "", ], target: @@ -63,10 +71,11 @@ final class ChatTemplateTests: XCTestCase { ), // facebook/blenderbot_small-90M Test( + name: "facebook/blenderbot_small-90M", chatTemplate: "{% for message in messages %}{% if message['role'] == 'user' %}{{ ' ' }}{% endif %}{{ message['content'] }}{% if not loop.last %}{{ ' ' }}{% endif %}{% endfor %}{{ eos_token }}", data: [ - "messages": messages, + "messages": exampleChatMessages, "eos_token": "", ], target: @@ -74,9 +83,10 @@ final class ChatTemplateTests: XCTestCase { ), // bigscience/bloom Test( + name: "bigscience/bloom", chatTemplate: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", data: [ - "messages": messages, + "messages": exampleChatMessages, "eos_token": "", ], target: @@ -84,19 +94,21 @@ final class ChatTemplateTests: XCTestCase { ), // EleutherAI/gpt-neox-20b Test( + name: "EleutherAI/gpt-neox-20b", chatTemplate: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", data: [ - "messages": messages, + "messages": exampleChatMessages, "eos_token": "<|endoftext|>", ], target: "Hello, how are you?<|endoftext|>I'm doing great. How can I help you today?<|endoftext|>I'd like to show off how chat templating works!<|endoftext|>" ), - // gpt2 + // GPT-2 Test( + name: "GPT-2", chatTemplate: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", data: [ - "messages": messages, + "messages": exampleChatMessages, "eos_token": "<|endoftext|>", ], target: @@ -104,10 +116,11 @@ final class ChatTemplateTests: XCTestCase { ), // hf-internal-testing/llama-tokenizer Test( + name: "hf-internal-testing/llama-tokenizer 1", chatTemplate: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", data: [ - "messages": messagesWithSystem, + "messages": exampleChatMessagesWithSystemPrompt, "bos_token": "", "eos_token": "", "USE_DEFAULT_PROMPT": true, @@ -117,10 +130,11 @@ final class ChatTemplateTests: XCTestCase { ), // hf-internal-testing/llama-tokenizer Test( + name: "hf-internal-testing/llama-tokenizer 2", chatTemplate: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", data: [ - "messages": messages, + "messages": exampleChatMessages, "bos_token": "", "eos_token": "", "USE_DEFAULT_PROMPT": true, @@ -130,6 +144,7 @@ final class ChatTemplateTests: XCTestCase { ), // hf-internal-testing/llama-tokenizer Test( + name: "hf-internal-testing/llama-tokenizer 3", chatTemplate: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}", data: [ @@ -156,9 +171,10 @@ final class ChatTemplateTests: XCTestCase { ), // openai/whisper-large-v3 Test( + name: "openai/whisper-large-v3", chatTemplate: "{% for message in messages %}{{ message.content }}{{ eos_token }}{% endfor %}", data: [ - "messages": messages, + "messages": exampleChatMessages, "eos_token": "<|endoftext|>", ], target: @@ -166,10 +182,11 @@ final class ChatTemplateTests: XCTestCase { ), // Qwen/Qwen1.5-1.8B-Chat Test( + name: "Qwen/Qwen1.5-1.8B-Chat", chatTemplate: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}", data: [ - "messages": messages, + "messages": exampleChatMessages, "add_generation_prompt": true, ], target: @@ -177,10 +194,11 @@ final class ChatTemplateTests: XCTestCase { ), // Qwen/Qwen1.5-1.8B-Chat Test( + name: "Qwen/Qwen1.5-1.8B-Chat 2", chatTemplate: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}", data: [ - "messages": messagesWithSystem, + "messages": exampleChatMessagesWithSystemPrompt, "add_generation_prompt": true, ], target: @@ -188,51 +206,425 @@ final class ChatTemplateTests: XCTestCase { ), // Qwen/Qwen1.5-1.8B-Chat Test( + name: "Qwen/Qwen1.5-1.8B-Chat 3", chatTemplate: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}", data: [ - "messages": messagesWithSystem + "messages": exampleChatMessagesWithSystemPrompt ], target: "<|im_start|>system\nYou are a friendly chatbot who always responds in the style of a pirate<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!" ), // THUDM/chatglm3-6b Test( + name: "THUDM/chatglm3-6b", chatTemplate: "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}", data: [ - "messages": messagesWithSystem + "messages": exampleChatMessagesWithSystemPrompt ], target: "[gMASK]sop<|system|>\n You are a friendly chatbot who always responds in the style of a pirate<|user|>\n Hello, how are you?<|assistant|>\n I\'m doing great. How can I help you today?<|user|>\n I\'d like to show off how chat templating works!" ), // google/gemma-2b-it Test( + name: "google/gemma-2b-it", chatTemplate: "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", data: [ - "messages": messages + "messages": exampleChatMessages ], target: "user\nHello, how are you?\nmodel\nI\'m doing great. How can I help you today?\nuser\nI\'d like to show off how chat templating works!\n" ), // Qwen/Qwen2.5-0.5B-Instruct Test( + name: "Qwen/Qwen2.5-0.5B-Instruct", chatTemplate: "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n", data: [ - "messages": messages + "messages": exampleChatMessages ], target: "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI\'m doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI\'d like to show off how chat templating works!<|im_end|>\n" ), + // Llama-3.2-11B-Vision-Instruct: text chat only + Test( + name: "Llama-3.2-11B-Vision-Instruct: text chat only", + chatTemplate: llama3_2visionChatTemplate, + data: [ + "messages": [ + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "Hello, how are you?", + ] as [String: Any] + ] as [[String: Any]], + ] as [String: Any], + [ + "role": "assistant", + "content": [ + [ + "type": "text", + "text": "I'm doing great. How can I help you today?", + ] as [String: Any] + ] as [[String: Any]], + ] as [String: Any], + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "I'd like to show off how chat templating works!", + ] as [String: Any] + ] as [[String: Any]], + ] as [String: Any], + ] as [[String: Any]] as Any, + "bos_token": "" as Any, + "date_string": "26 Jul 2024" as Any, + "tools_in_user_message": true as Any, + "system_message": "You are a helpful assistant." as Any, + "add_generation_prompt": true as Any, + ], + target: + "\n<|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nHello, how are you?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nI'm doing great. How can I help you today?<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nI'd like to show off how chat templating works!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ), + // Llama-3.2-11B-Vision-Instruct: with images + Test( + name: "Llama-3.2-11B-Vision-Instruct: with images", + chatTemplate: llama3_2visionChatTemplate, + data: [ + "messages": [ + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "What's in this image?", + ] as [String: Any], + [ + "type": "image", + "image": "base64_encoded_image_data", + ] as [String: Any], + ] as [[String: Any]], + ] as [String: Any] + ] as [[String: Any]] as Any, + "bos_token": "" as Any, + "add_generation_prompt": true as Any, + ], + target: + "\n<|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat's in this image?<|image|><|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ), + // Qwen2-VL text only + Test( + name: "Qwen2-VL-7B-Instruct: text only", + chatTemplate: qwen2VLChatTemplate, + data: [ + "messages": exampleChatMessages, + "add_generation_prompt": true, + ], + target: """ + <|im_start|>system + You are a helpful assistant.<|im_end|> + <|im_start|>user + Hello, how are you?<|im_end|> + <|im_start|>assistant + I'm doing great. How can I help you today?<|im_end|> + <|im_start|>user + I'd like to show off how chat templating works!<|im_end|> + <|im_start|>assistant + + """ + ), + // Qwen2-VL with images + Test( + name: "Qwen2-VL-7B-Instruct: with images", + chatTemplate: qwen2VLChatTemplate, + data: [ + "messages": [ + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "What's in this image?", + ] as [String: String], + [ + "type": "image", + "image_url": "example.jpg", + ] as [String: String], + ] as [[String: String]], + ] as [String: Any] + ] as [[String: Any]], + "add_generation_prompt": true, + "add_vision_id": true, + ], + target: """ + <|im_start|>system + You are a helpful assistant.<|im_end|> + <|im_start|>user + What's in this image?Picture 1: <|vision_start|><|image_pad|><|vision_end|><|im_end|> + <|im_start|>assistant + + """ + ), + // Qwen2-VL with video + Test( + name: "Qwen2-VL-7B-Instruct: with video", + chatTemplate: qwen2VLChatTemplate, + data: [ + "messages": [ + [ + "role": "user", + "content": [ + [ + "type": "text", + "text": "What's happening in this video?", + ] as [String: String], + [ + "type": "video", + "video_url": "example.mp4", + ] as [String: String], + ] as [[String: String]], + ] as [String: Any] + ] as [[String: Any]], + "add_generation_prompt": true, + "add_vision_id": true, + ], + target: """ + <|im_start|>system + You are a helpful assistant.<|im_end|> + <|im_start|>user + What's happening in this video?Video 1: <|vision_start|><|video_pad|><|vision_end|><|im_end|> + <|im_start|>assistant + + """ + ), ] func testDefaultTemplates() throws { - for test in defaultTemplates { + for test in defaultTemplateTests { let template = try Template(test.chatTemplate) let result = try template.render(test.data) + if result != test.target { + print("Test for \(test.name) failed") + print("Target:") + print(test.target) + print("Result:") + print(result) + } XCTAssertEqual(result.debugDescription, test.target.debugDescription) } } + + func testCustomTemplates() throws { + let tests = [ + Test( + name: "HuggingFaceH4/zephyr-7b-beta (add_generation_prompt=false)", + chatTemplate: + "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", + data: [ + "messages": exampleChatMessagesWithSystemPrompt, "eos_token": "", + "add_generation_prompt": false, + ] + as [String: Any], + target: + "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate\n<|user|>\nHello, how are you?\n<|assistant|>\nI'm doing great. How can I help you today?\n<|user|>\nI'd like to show off how chat templating works!\n" + ), + Test( + name: "HuggingFaceH4/zephyr-7b-beta (add_generation_prompt=true)", + chatTemplate: + "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}", + data: [ + "messages": [ + [ + "role": "system", + "content": "You are a friendly chatbot who always responds in the style of a pirate", + ], + ["role": "user", "content": "How many helicopters can a human eat in one sitting?"], + ], "eos_token": "", "add_generation_prompt": true, + ] as [String: Any], + target: + "<|system|>\nYou are a friendly chatbot who always responds in the style of a pirate\n<|user|>\nHow many helicopters can a human eat in one sitting?\n<|assistant|>\n" + ), + Test( + name: "HuggingFaceH4/zephyr-7b-gemma-v0.1", + chatTemplate: + "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}", + data: [ + "messages": exampleChatMessages, "bos_token": "", "eos_token": "", + "add_generation_prompt": false, + ] as [String: Any], + target: + "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" + ), + Test( + name: "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ", + chatTemplate: + "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", + data: ["messages": exampleChatMessages, "bos_token": "", "eos_token": ""] as [String: Any], + target: + "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" + ), + Test( + name: "mistralai/Mixtral-8x7B-Instruct-v0.1", + chatTemplate: + "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", + data: ["messages": exampleChatMessages, "bos_token": "", "eos_token": ""] as [String: Any], + target: + "[INST] Hello, how are you? [/INST]I'm doing great. How can I help you today?[INST] I'd like to show off how chat templating works! [/INST]" + ), + Test( + name: "cognitivecomputations/dolphin-2.5-mixtral-8x7b", + chatTemplate: + "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + data: ["messages": exampleChatMessages] as [String: Any], + target: + "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" + ), + Test( + name: "openchat/openchat-3.5-0106", + chatTemplate: + "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}", + data: [ + "messages": exampleChatMessages, "bos_token": "", "eos_token": "", + "add_generation_prompt": false, + ] as [String: Any], + target: + "GPT4 Correct User: Hello, how are you?<|end_of_turn|>GPT4 Correct Assistant: I'm doing great. How can I help you today?<|end_of_turn|>GPT4 Correct User: I'd like to show off how chat templating works!<|end_of_turn|>" + ), + Test( + name: "upstage/SOLAR-10.7B-Instruct-v1.0", + chatTemplate: + "{% for message in messages %}{% if message['role'] == 'system' %}{% if message['content']%}{{'### System:\n' + message['content']+'\n\n'}}{% endif %}{% elif message['role'] == 'user' %}{{'### User:\n' + message['content']+'\n\n'}}{% elif message['role'] == 'assistant' %}{{'### Assistant:\n' + message['content']}}{% endif %}{% if loop.last and add_generation_prompt %}{{ '### Assistant:\n' }}{% endif %}{% endfor %}", + data: ["messages": exampleChatMessages] as [String: Any], + target: + "### User:\nHello, how are you?\n\n### Assistant:\nI'm doing great. How can I help you today?### User:\nI'd like to show off how chat templating works!\n\n" + ), + Test( + name: "codellama/CodeLlama-70b-Instruct-hf", + chatTemplate: + "{% if messages[0]['role'] == 'system' %}{% set user_index = 1 %}{% else %}{% set user_index = 0 %}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != ((loop.index0 + user_index) % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 %}{{ '' }}{% endif %}{% set content = 'Source: ' + message['role'] + '\n\n ' + message['content'] | trim %}{{ content + ' ' }}{% endfor %}{{'Source: assistant\nDestination: user\n\n '}}", + data: ["messages": exampleChatMessages] as [String: Any], + target: + "Source: user\n\n Hello, how are you? Source: assistant\n\n I'm doing great. How can I help you today? Source: user\n\n I'd like to show off how chat templating works! Source: assistant\nDestination: user\n\n " + ), + Test( + name: "Deci/DeciLM-7B-instruct", + chatTemplate: + "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '### User:\n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ '### System:\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '### Assistant:\n' + message['content'] }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '### Assistant:' }}\n{% endif %}\n{% endfor %}", + data: ["messages": exampleChatMessages] as [String: Any], + target: + "### User:\nHello, how are you?\n### Assistant:\nI'm doing great. How can I help you today?\n### User:\nI'd like to show off how chat templating works!\n" + ), + Test( + name: "Qwen/Qwen1.5-72B-Chat", + chatTemplate: + "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", + data: ["messages": exampleChatMessages] as [String: Any], + target: + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" + ), + Test( + name: "deepseek-ai/deepseek-llm-7b-chat", + chatTemplate: + "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}", + data: [ + "messages": exampleChatMessages, "bos_token": "<|begin of sentence|>", + "eos_token": "<|end of sentence|>", + ] as [String: Any], + target: + "<|begin of sentence|>User: Hello, how are you?\n\nAssistant: I'm doing great. How can I help you today?<|end of sentence|>User: I'd like to show off how chat templating works!\n\n" + ), + Test( + name: "h2oai/h2o-danube-1.8b-chat", + chatTemplate: + "{% for message in messages %}{% if message['role'] == 'user' %}{{ '<|prompt|>' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ '<|system|>' + message['content'] + eos_token }}{% elif message['role'] == 'assistant' %}{{ '<|answer|>' + message['content'] + eos_token }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<|answer|>' }}{% endif %}{% endfor %}", + data: ["messages": exampleChatMessages, "eos_token": ""] as [String: Any], + target: + "<|prompt|>Hello, how are you?<|answer|>I'm doing great. How can I help you today?<|prompt|>I'd like to show off how chat templating works!" + ), + Test( + name: "internlm/internlm2-chat-7b", + chatTemplate: + "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}", + data: ["messages": exampleChatMessages, "bos_token": "", "eos_token": ""] as [String: Any], + target: + "<|im_start|>user\nHello, how are you?<|im_end|>\n<|im_start|>assistant\nI'm doing great. How can I help you today?<|im_end|>\n<|im_start|>user\nI'd like to show off how chat templating works!<|im_end|>\n" + ), + Test( + name: "TheBloke/deepseek-coder-33B-instruct-AWQ", + chatTemplate: + "{%- set found_item = false -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set found_item = true -%}\n {%- endif -%}\n{%- endfor -%}\n{%- if not found_item -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'### Response:\\n'}}\n", + data: ["messages": exampleChatMessages] as [String: Any], + target: + "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n### Response:\n" + ), + Test( + name: "ericzzz/falcon-rw-1b-chat", + chatTemplate: + "{% for message in messages %}{% if loop.index > 1 and loop.previtem['role'] != 'assistant' %}{{ ' ' }}{% endif %}{% if message['role'] == 'system' %}{{ '[SYS] ' + message['content'].strip() }}{% elif message['role'] == 'user' %}{{ '[INST] ' + message['content'].strip() }}{% elif message['role'] == 'assistant' %}{{ '[RESP] ' + message['content'] + eos_token }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ ' [RESP] ' }}{% endif %}", + data: ["messages": exampleChatMessages, "eos_token": "<|endoftext|>"] as [String: Any], + target: + "[INST] Hello, how are you? [RESP] I'm doing great. How can I help you today?<|endoftext|>[INST] I'd like to show off how chat templating works!" + ), + Test( + name: "abacusai/Smaug-34B-v0.1", + chatTemplate: + "{%- for idx in range(0, messages|length) -%}\n{%- if messages[idx]['role'] == 'user' -%}\n{%- if idx > 1 -%}\n{{- bos_token + '[INST] ' + messages[idx]['content'] + ' [/INST]' -}}\n{%- else -%}\n{{- messages[idx]['content'] + ' [/INST]' -}}\n{%- endif -%}\n{% elif messages[idx]['role'] == 'system' %}\n{{- '[INST] <>\\n' + messages[idx]['content'] + '\\n<>\\n\\n' -}}\n{%- elif messages[idx]['role'] == 'assistant' -%}\n{{- ' ' + messages[idx]['content'] + ' ' + eos_token -}}\n{% endif %}\n{% endfor %}", + data: ["messages": exampleChatMessages, "bos_token": "", "eos_token": ""] as [String: Any], + target: + "Hello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]" + ), + Test( + name: "maywell/Synatra-Mixtral-8x7B", + chatTemplate: + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n{% for message in messages %}{% if message['role'] == 'user' %}### Instruction:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'assistant' %}### Response:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'system' %}{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}\n### Response:\n{% endif %}", + data: ["messages": exampleChatMessages] as [String: Any], + target: + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nHello, how are you?### Response:\nI'm doing great. How can I help you today?### Instruction:\nI'd like to show off how chat templating works!" + ), + Test( + name: "deepseek-ai/deepseek-coder-33b-instruct", + chatTemplate: + "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}", + data: ["messages": exampleChatMessages, "bos_token": "<|begin of sentence|>", "eos_token": "<|EOT|>"] + as [String: Any], + target: + "<|begin of sentence|>You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n### Instruction:\nHello, how are you?\n### Response:\nI'm doing great. How can I help you today?\n<|EOT|>\n### Instruction:\nI'd like to show off how chat templating works!\n" + ), + Test( + name: "maywell/Synatra-Mixtral-8x7B", + chatTemplate: + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n{% for message in messages %}{% if message['role'] == 'user' %}### Instruction:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'assistant' %}### Response:\n{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% elif message['role'] == 'system' %}{{ message['content']|trim -}}{% if not loop.last %}{% endif %}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}\n### Response:\n{% endif %}", + data: ["messages": exampleChatMessagesWithSystemPrompt] as [String: Any], + target: + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\nYou are a friendly chatbot who always responds in the style of a pirate### Instruction:\nHello, how are you?### Response:\nI'm doing great. How can I help you today?### Instruction:\nI'd like to show off how chat templating works!" + ), + Test( + name: "maywell/PiVoT-MoE", + chatTemplate: + "{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}{% for message in messages %}{% if message['role'] == 'system' %}{{ message['content']|trim }}{% elif message['role'] == 'user' %}### Instruction: {{ message['content']|trim }}{% elif message['role'] == 'assistant' %}### Response: {{ message['content']|trim }}{% elif message['role'] == 'user_context' %}### Input: {{ message['content']|trim }}{% endif %}{% if not loop.last %}\n{% endif %}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}### Response:{% endif %}", + data: ["messages": exampleChatMessagesWithSystemPrompt] as [String: Any], + target: + "You are a friendly chatbot who always responds in the style of a pirateYou are a friendly chatbot who always responds in the style of a pirate### Instruction: Hello, how are you?### Response: I'm doing great. How can I help you today?### Instruction: I'd like to show off how chat templating works!" + ), + ] + + for test in tests { + let template = try Template(test.chatTemplate) + let result = try template.render(test.data) + if result != test.target { + print("Test for \(test.name) failed") + print("Target:") + print(test.target) + print("Result:") + print(result) + } + XCTAssertEqual(result, test.target) + } + } + } diff --git a/Tests/InterpreterTests.swift b/Tests/InterpreterTests.swift index d402f84..631d2e6 100644 --- a/Tests/InterpreterTests.swift +++ b/Tests/InterpreterTests.swift @@ -141,17 +141,18 @@ final class InterpreterTests: XCTestCase { for test in tests { let env = Environment() try env.set(name: "True", value: true) - for (key, value) in test.data { try env.set(name: key, value: value) } - let tokens = try tokenize(test.template, options: test.options) let parsed = try parse(tokens: tokens) let interpreter = Interpreter(env: env) - let result = try interpreter.run(program: parsed).value as! String - - XCTAssertEqual(result.debugDescription, test.target.debugDescription) + let result = try interpreter.run(program: parsed) + if let stringResult = result as? StringValue { + XCTAssertEqual(stringResult.value.debugDescription, test.target.debugDescription) + } else { + XCTFail("Expected a StringValue, but got \(type(of: result))") + } } } } diff --git a/Tests/ToolUseTests.swift b/Tests/ToolUseTests.swift new file mode 100644 index 0000000..b19cf0b --- /dev/null +++ b/Tests/ToolUseTests.swift @@ -0,0 +1,397 @@ +import XCTest + +@testable import Jinja + +final class ToolUseTests: XCTestCase { + let exampleFunctionCalling: [[String: Any?]] = [ + [ + "role": "assistant", + "content": nil, + "tool_calls": [ + [ + "type": "function", + "function": [ + "name": "get_current_weather", + "arguments": "{\n \"location\": \"Hanoi\"\n}", + ] as [String: Any?], + ] as [String: Any?] + ] as [[String: Any?]], + ] as [String: Any?], + ["role": "user", "content": "what's the weather like in Hanoi?"] as [String: Any?], + ] + + // Example adapted from https://huggingface.co/fireworks-ai/firefunction-v1 + let exampleFunctionSpec: [[String: Any]] = [ + [ + "name": "get_stock_price", + "description": "Get the current stock price", + "parameters": [ + "type": "object", + "properties": [ + "symbol": [ + "type": "string", + "description": "The stock symbol, e.g. AAPL, GOOG", + ] + ], + "required": ["symbol"], + ], + ] as [String: Any], + [ + "name": "check_word_anagram", + "description": "Check if two words are anagrams of each other", + "parameters": [ + "type": "object", + "properties": [ + "word1": [ + "type": "string", + "description": "The first word", + ], + "word2": [ + "type": "string", + "description": "The second word", + ], + ], + "required": ["word1", "word2"], + ], + ] as [String: Any], + ] + lazy var exampleFunctionCallingWithSystem: [[String: Any]] = [ + ["role": "system", "content": "You are a helpful assistant with access to functions. Use them if required."], + [ + "role": "functions", + "content": String( + data: try! JSONSerialization.data( + withJSONObject: exampleFunctionSpec, + options: [.prettyPrinted, .withoutEscapingSlashes] + ), + encoding: .utf8 + )!, + ], + ["role": "user", "content": "Hi, can you tell me the current stock price of AAPL?"], + ] + + let exampleToolJSONSchemas: [String: [String: Any]] = [ + "get_current_weather": [ + "type": "function", + "function": [ + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": [ + "type": "object", + "properties": [ + "location": [ + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + ], + "unit": [ + "type": "string", + "enum": ["celsius", "fahrenheit"], + ], + ], + "required": ["location"], + ], + ], + ] as [String: Any], + "get_current_temperature_v1": [ + "type": "function", + "function": [ + "name": "get_current_temperature", + "description": "Get the current temperature at a location.", + "parameters": [ + "type": "object", + "properties": [ + "location": [ + "type": "string", + "description": "The location to get the temperature for, in the format \"City, Country\"", + ] + ], + "required": ["location"], + ], + "return": [ + "type": "number", + "description": + "The current temperature at the specified location in the specified units, as a float.", + ], + ], + ] as [String: Any], + "get_current_temperature_v2": [ + "type": "function", + "function": [ + "name": "get_current_temperature", + "description": "Get the current temperature at a location.", + "parameters": [ + "type": "object", + "properties": [ + "location": [ + "type": "string", + "description": "The location to get the temperature for, in the format \"City, Country\"", + ], + "unit": [ + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "The unit to return the temperature in.", + ], + ], + "required": ["location", "unit"], + ], + "return": [ + "type": "number", + "description": + "The current temperature at the specified location in the specified units, as a float.", + ], + ], + ] as [String: Any], + "get_current_wind_speed": [ + "type": "function", + "function": [ + "name": "get_current_wind_speed", + "description": "Get the current wind speed in km/h at a given location.", + "parameters": [ + "type": "object", + "properties": [ + "location": [ + "type": "string", + "description": "The location to get the temperature for, in the format \"City, Country\"", + ] + ], + "required": ["location"], + ], + "return": [ + "type": "number", + "description": "The current wind speed at the given location in km/h, as a float.", + ], + ], + ] as [String: Any], + ] + + lazy var exampleListOfTools: [[String: Any]] = [ + exampleToolJSONSchemas["get_current_temperature_v2"]!, + exampleToolJSONSchemas["get_current_wind_speed"]!, + ] + + func testMeetKaiFunctionaryMediumV2_2() throws { + let chatTemplate = """ + {#v2.2#}\n{% for message in messages %}\n{% if message['role'] == 'user' or message['role'] == 'system' %}\n{{ '<|from|>' + message['role'] + '\n<|recipient|>all\n<|content|>' + message['content'] + '\n' }}{% elif message['role'] == 'tool' %}\n{{ '<|from|>' + message['name'] + '\n<|recipient|>all\n<|content|>' + message['content'] + '\n' }}{% else %}\n{% set contain_content='no'%}\n{% if message['content'] is not none %}\n{{ '<|from|>assistant\n<|recipient|>all\n<|content|>' + message['content'] }}{% set contain_content='yes'%}\n{% endif %}\n{% if 'tool_calls' in message and message['tool_calls'] is not none %}\n{% for tool_call in message['tool_calls'] %}\n{% set prompt='<|from|>assistant\n<|recipient|>' + tool_call['function']['name'] + '\n<|content|>' + tool_call['function']['arguments'] %}\n{% if loop.index == 1 and contain_content == "no" %}\n{{ prompt }}{% else %}\n{{ '\n' + prompt}}{% endif %}\n{% endfor %}\n{% endif %}\n{{ '<|stop|>\n' }}{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}{{ '<|from|>assistant\n<|recipient|>' }}{% endif %} + """ + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": exampleFunctionCalling, + "bos_token": "", + "eos_token": "", + "add_generation_prompt": false, + ]) + let target = + """ + <|from|>assistant\n<|recipient|>get_current_weather\n<|content|>{\n "location": "Hanoi"\n}<|stop|>\n<|from|>user\n<|recipient|>all\n<|content|>what's the weather like in Hanoi?\n + """ + + if target != result { + print("::: testMeetKaiFunctionaryMediumV2_2 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testFireworksAIFireFunctionV1() throws { + let chatTemplate = """ + {%- set message_roles = ['SYSTEM', 'FUNCTIONS', 'USER', 'ASSISTANT', 'TOOL'] -%}\n{%- set ns = namespace(seen_non_system=false, messages=messages, content='', functions=[]) -%}\n{{ bos_token }}\n{#- Basic consistency checks -#}\n{%- if not ns.messages -%}\n {{ raise_exception('No messages') }}\n{%- endif -%}\n{%- if ns.messages[0]['role'] | upper != 'SYSTEM' -%}\n {%- set ns.messages = [{'role': 'SYSTEM', 'content': 'You are a helpful assistant with access to functions. Use them if required.'}] + ns.messages -%}\n{%- endif -%}\n{%- if ns.messages | length < 2 or ns.messages[0]['role'] | upper != 'SYSTEM' or ns.messages[1]['role'] | upper != 'FUNCTIONS' -%}\n {{ raise_exception('Expected either "functions" or ["system", "functions"] as the first messages') }}\n{%- endif -%}\n{%- for message in ns.messages -%}\n {%- set role = message['role'] | upper -%}\n {#- Validation -#}\n {%- if role not in message_roles -%}\n {{ raise_exception('Invalid role ' + message['role'] + '. Only ' + message_roles + ' are supported.') }}\n {%- endif -%}\n {%- set ns.content = message['content'] if message.get('content') else '' -%}\n {#- Move tool calls inside the content -#}\n {%- if 'tool_calls' in message -%}\n {%- for call in message['tool_calls'] -%}\n {%- set ns.content = ns.content + '{"name": "' + call['function']['name'] + '", "arguments": ' + call['function']['arguments'] + '}' -%}\n {%- endfor -%}\n {%- endif -%}\n {%- if role == 'ASSISTANT' and '' not in ns.content -%}\n {%- set ns.content = '' + ns.content -%}\n {%- endif -%}\n {%- if role == 'ASSISTANT' -%}\n {%- set ns.content = ns.content + eos_token -%}\n {%- endif -%}\n {{ role }}: {{ ns.content }}{{ '\\n\\n' }}\n{%- endfor -%}\nASSISTANT:{{ ' ' }}\n + """ + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": exampleFunctionCallingWithSystem, + "bos_token": "", + "eos_token": "", + "add_generation_prompt": false, + ]) + let target = """ + SYSTEM: You are a helpful assistant with access to functions. Use them if required.\n\nFUNCTIONS: [\n {\n "name": "get_stock_price",\n "description": "Get the current stock price",\n "parameters": {\n "type": "object",\n "properties": {\n "symbol": {\n "type": "string",\n "description": "The stock symbol, e.g. AAPL, GOOG"\n }\n },\n "required": [\n "symbol"\n ]\n }\n },\n {\n "name": "check_word_anagram",\n "description": "Check if two words are anagrams of each other",\n "parameters": {\n "type": "object",\n "properties": {\n "word1": {\n "type": "string",\n "description": "The first word"\n },\n "word2": {\n "type": "string",\n "description": "The second word"\n }\n },\n "required": [\n "word1",\n "word2"\n ]\n }\n }\n]\n\nSYSTEM: You are a helpful assistant with access to functions. Use them if required.\n\nUSER: Hi, can you tell me the current stock price of AAPL?\n\nASSISTANT: + """ + + if target != result { + print("::: testFireworksAIFireFunctionV1 failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testMistral7BInstructV0_3JSONSchema() throws { + let chatTemplate = + "{{- bos_token }}\n{%- set user_messages = messages | selectattr('role', 'equalto', 'user') | list %}\n{%- for message in messages %}\n {%- if message['role'] == 'user' %}\n {%- if tools and (message == user_messages[-1]) %}\n {{- ' [AVAILABLE_TOOLS] [' }}\n {%- for tool in tools %}\n\t\t{%- set tool = tool.function %}\n\t\t{{- '{\"type\": \"function\", \"function\": {' }}\n\t\t{%- for key, val in tool|items if key != \"return\" %}\n\t\t {%- if val is string %}\n\t\t\t{{- '\"' + key + '\": \"' + val + '\"' }}\n\t\t {%- else %}\n\t\t\t{{- '\"' + key + '\": ' + val|tojson }}\n\t\t {%- endif %}\n\t\t {%- if not loop.last %}\n\t\t\t{{- \", \" }}\n\t\t {%- endif %}\n\t\t{%- endfor %}\n\t\t{{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- ' [/AVAILABLE_TOOLS]' }}\n {%- endif %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- elif message['role'] == 'assistant' %}\n {%- if message.tool_calls is defined and message.tool_calls|length > 0 %}\n {{- ' [TOOL_CALLS] [' }}\n {%- for tool_call in message.tool_calls %}\n {{- {\"name\": tool_call.function.name, \"arguments\": tool_call.function.arguments, \"id\": tool_call.id}|tojson }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- '] ' }}\n {{- eos_token }}\n \t{%- elif message.content is defined %}\n\t {{- ' ' + message.content + ' ' + eos_token}}\n {%- endif %}\n {%- elif message['role'] == 'tool' %}\n {{- ' [TOOL_RESULTS] ' }}\n {{- '{\"call_id\": \"' + message.tool_call_id + '\", \"content\": ' + message.content|string + '}' }}\n {{- ' [/TOOL_RESULTS] ' }}\n {%- endif %}\n{%- endfor %}\n" + let template = try Template(chatTemplate) + + let toolsJSON = try JSONSerialization.data(withJSONObject: exampleListOfTools, options: [.fragmentsAllowed]) + let toolsJSONString = String(data: toolsJSON, encoding: .utf8)! + + let result = try template.render([ + "messages": [ + [ + "role": "system", + "content": + "You are a bot that responds to weather queries. You should reply with the unit used in the queried location.", + ], + ["role": "user", "content": "Hey, what's the temperature in Paris right now?"], + [ + "role": "assistant", + "tool_calls": [ + [ + "id": "abcdef123", + "type": "function", + "function": [ + "name": "get_current_temperature", + "arguments": ["location": "Paris, France", "unit": "celsius"], + ], + ] + ], + ], + ["role": "tool", "tool_call_id": "abcdef123", "name": "get_current_temperature", "content": "22.0"], + ], + "tools": exampleListOfTools, + "tools_json": toolsJSONString, + "bos_token": "", + "eos_token": "", + ]) + let target = """ + [AVAILABLE_TOOLS] [{"type": "function", "function": {"name": "get_current_temperature", "description": "Get the current temperature at a location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the temperature for, in the format \\"City, Country\\""}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "description": "The unit to return the temperature in."}}, "required": ["location", "unit"]}}}, {"type": "function", "function": {"name": "get_current_wind_speed", "description": "Get the current wind speed in km/h at a given location.", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The location to get the temperature for, in the format \\"City, Country\\""}}, "required": ["location"]}}}] [/AVAILABLE_TOOLS] [INST] Hey, what\'s the temperature in Paris right now? [/INST] [TOOL_CALLS] [{"name": "get_current_temperature", "arguments": {"location": "Paris, France", "unit": "celsius"}, "id": "abcdef123"}] [TOOL_RESULTS] {"call_id": "abcdef123", "content": 22.0} [/TOOL_RESULTS] + """ + + if target != result { + print("::: testMistral7BInstructV0_3JSONSchema failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testCISCaiMistral7BInstructV0_3SOTAGGUF() throws { + let chatTemplate = """ + {{ bos_token }}{% set ns = namespace(lastuser=-1, system=false, functions=false) %}{% if tools %}{% for message in messages %}{% if message['role'] == 'user' %}{% set ns.lastuser = loop.index0 %}{% elif message['role'] == 'system' %}{% set ns.system = message['content'] %}{% endif %}{% endfor %}{% set ns.functions = tools|selectattr('type','eq','function')|map(attribute='function')|list|tojson %}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}{% if loop.index0 == ns.lastuser and ns.functions %}{{ '[AVAILABLE_TOOLS] ' }}{{ ns.functions }}{{ '[/AVAILABLE_TOOLS]' }}{% endif %}{{ '[INST] ' }}{% if loop.index0 == ns.lastuser and ns.system %}{{ ns.system + ' ' }}{% endif %}{{ message['content'] }}{{ '[/INST]' }}{% elif message['role'] == 'tool' %}{{ '[TOOL_RESULTS] ' }}{{ dict(call_id=message['tool_call_id'], content=message['content'])|tojson }}{{ '[/TOOL_RESULTS]' }}{% elif message['role'] == 'assistant' %}{% if message['tool_calls'] %}{{ '[TOOL_CALLS] [' }}{% for call in message['tool_calls'] %}{% if call['type'] == 'function' %}{{ dict(id=call['id'], name=call['function']['name'], arguments=call['function']['arguments'])|tojson }}{% endif %}{% if not loop.last %}{{ ', ' }}{% endif %}{% endfor %}{{ ']' }}{% else %}{{ message['content'] }}{% endif %}{{ eos_token }}{% endif %}{% endfor %} + """ + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": [ + [ + "role": "user", + "content": "What's the weather like in Oslo and Stockholm?", + ] + ], + "tools": [exampleToolJSONSchemas["get_current_weather"]!], + "bos_token": "", + "eos_token": "", + ]) + let target = + """ + [AVAILABLE_TOOLS] [{"name": "get_current_weather", "description": "Get the current weather in a given location", "parameters": {"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}, "required": ["location"]}}][/AVAILABLE_TOOLS][INST] What's the weather like in Oslo and Stockholm?[/INST] + """ + + if target != result { + print("::: testCISCaiMistral7BInstructV0_3SOTAGGUF failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testNousResearchHermes2ProLlama38BJSONSchema() throws { + let chatTemplate = """ + {%- macro json_to_python_type(json_spec) %}\n{%- set basic_type_map = {\n "string": "str",\n "number": "float",\n "integer": "int",\n "boolean": "bool"\n} %}\n\n{%- if basic_type_map[json_spec.type] is defined %}\n {{- basic_type_map[json_spec.type] }}\n{%- elif json_spec.type == "array" %}\n {{- "list[" + json_to_python_type(json_spec|items) + "]"}}\n{%- elif json_spec.type == "object" %}\n {%- if json_spec.additionalProperties is defined %}\n {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']'}}\n {%- else %}\n {{- "dict" }}\n {%- endif %}\n{%- elif json_spec.type is iterable %}\n {{- "Union[" }}\n {%- for t in json_spec.type %}\n {{- json_to_python_type({"type": t}) }}\n {%- if not loop.last %}\n {{- "," }} \n {%- endif %}\n {%- endfor %}\n {{- "]" }}\n{%- else %}\n {{- "Any" }}\n{%- endif %}\n{%- endmacro %}\n\n\n{{- bos_token }}\n{{- "You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }}\n{%- for tool in tools %}\n {%- if tool.function is defined %}\n {%- set tool = tool.function %}\n {%- endif %}\n {{- '{"type": "function", "function": ' }}\n {{- '{"name": ' + tool.name + '", ' }}\n {{- '"description": "' + tool.name + '(' }}\n {%- for param_name, param_fields in tool.parameters.properties|items %}\n {{- param_name + ": " + json_to_python_type(param_fields) }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- ")" }}\n {%- if tool.return is defined %}\n {{- " -> " + json_to_python_type(tool.return) }}\n {%- endif %}\n {{- " - " + tool.description + "\\n\\n" }}\n {%- for param_name, param_fields in tool.parameters.properties|items %}\n {%- if loop.first %}\n {{- " Args:\\n" }}\n {%- endif %}\n {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }}\n {%- endfor %}\n {%- if tool.return is defined and tool.return.description is defined %}\n {{- "\\n Returns:\\n " + tool.return.description }}\n {%- endif %}\n {{- '"' }}\n {{- ', "parameters": ' }}\n {%- if tool.parameters.properties | length == 0 %}\n {{- "{}" }}\n {%- else %}\n {{- tool.parameters | tojson}}\n {%- endif %}\n {{- "}" }}\n {%- if not loop.last %}\n {{- "\\n" }}\n {%- endif %}\n{%- endfor %}\n{{- " " }}\n{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}\n' }}\n{{- "For each function call return a json object with function name and arguments within XML tags as follows:\n" }}\n{{- "\n" }}\n{{- '{"arguments": , "name": }\n' }}\n{{- '<|im_end|>' }}\n{%- for message in messages %}\n {%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == "assistant" %}\n {{- '<|im_start|>' + message.role + '\\n\\n' }}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '{ ' }}\n {%- if tool_call.arguments is defined %}\n {{- '"arguments": ' }}\n {{- tool_call.arguments|tojson }}\n {{- ', '}}\n {%- endif %}\n {{- '"name": "' }}\n {{- tool_call.name }}\n {{- '"}' }}\n {{- '\\n ' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == "tool" %}\n {%- if not message.name is defined %}\n {{- raise_exception("Tool response dicts require a 'name' key indicating the name of the called function!") }}\n {%- endif %}\n {{- '<|im_start|>' + message.role + '\\n\\n' }}\n {{- '{"name": "' }}\n {{- message.name }}\n {{- '", "content": ' }}\n {{- message.content|tojson + '}' }}\n {{- '\\n <|im_end|>\\n' }} \n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n + """ + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": [["role": "user", "content": "Fetch the stock fundamentals data for Tesla (TSLA)"]], + "tools": [ + [ + "type": "function", + "function": [ + "name": "get_stock_fundamentals", + "description": "Get fundamental data for a given stock symbol using yfinance API.", + "parameters": [ + "type": "object", + "properties": ["symbol": ["type": "string", "description": "The stock symbol."]], + "required": ["symbol"], + ], + "return": [ + "type": "object", + "description": + "A dictionary containing fundamental data.\n\nKeys:\n - 'symbol': The stock symbol.\n - 'company_name': The long name of the company.\n - 'sector': The sector to which the company belongs.\n - 'industry': The industry to which the company belongs.\n - 'market_cap': The market capitalization of the company.\n - 'pe_ratio': The forward price-to-earnings ratio.\n - 'pb_ratio': The price-to-book ratio.\n - 'dividend_yield': The dividend yield.\n - 'eps': The trailing earnings per share.\n - 'beta': The beta value of the stock.\n - '52_week_high': The 52-week high price of the stock.\n - '52_week_low': The 52-week low price of the stock.", + ], + ], + ] as [String: Any] + ], + "bos_token": "<|begin_of_text|>", + "eos_token": "<|im_end|>", + "add_generation_prompt": true, + ]) + let target = """ + <|begin_of_text|>You are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: {"type": "function", "function": {"name": get_stock_fundamentals", "description": "get_stock_fundamentals(symbol: str) -> dict - Get fundamental data for a given stock symbol using yfinance API.\n\n Args:\n symbol(str): The stock symbol.\n Returns:\n A dictionary containing fundamental data.\n\nKeys:\n - 'symbol': The stock symbol.\n - 'company_name': The long name of the company.\n - 'sector': The sector to which the company belongs.\n - 'industry': The industry to which the company belongs.\n - 'market_cap': The market capitalization of the company.\n - 'pe_ratio': The forward price-to-earnings ratio.\n - 'pb_ratio': The price-to-book ratio.\n - 'dividend_yield': The dividend yield.\n - 'eps': The trailing earnings per share.\n - 'beta': The beta value of the stock.\n - '52_week_high': The 52-week high price of the stock.\n - '52_week_low': The 52-week low price of the stock.", "parameters": {"type": "object", "properties": {"symbol": {"type": "string", "description": "The stock symbol."}}, "required": ["symbol"]}} Use the following pydantic model json schema for each tool call you will make: {"properties": {"arguments": {"title": "Arguments", "type": "object"}, "name": {"title": "Name", "type": "string"}}, "required": ["arguments", "name"], "title": "FunctionCall", "type": "object"}\nFor each function call return a json object with function name and arguments within XML tags as follows:\n\n{"arguments": , "name": }\n<|im_end|><|im_start|>user\nFetch the stock fundamentals data for Tesla (TSLA)<|im_end|>\n<|im_start|>assistant\n + """ + + if target != result { + print("::: testNousResearchHermes2ProLlama38BJSONSchema failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } + + func testMistralNemoInstruct2407() throws { + let chatTemplate = + "{%- if messages[0][\"role\"] == \"system\" %}\n {%- set system_message = messages[0][\"content\"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n{%- set user_messages = loop_messages | selectattr(\"role\", \"equalto\", \"user\") | list %}\n\n{%- for message in loop_messages | rejectattr(\"role\", \"equalto\", \"tool\") | rejectattr(\"role\", \"equalto\", \"tool_results\") | selectattr(\"tool_calls\", \"undefined\") %}\n {%- if (message[\"role\"] == \"user\") != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\"After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif %}\n{%- endfor %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if message[\"role\"] == \"user\" %}\n {%- if tools is not none and (message == user_messages[-1]) %}\n {{- \"[AVAILABLE_TOOLS][\" }}\n {%- for tool in tools %}\n {%- set tool = tool.function %}\n {{- '{\"type\": \"function\", \"function\": {' }}\n {%- for key, val in tool.items() if key != \"return\" %}\n {%- if val is string %}\n {{- '\"' + key + '\": \"' + val + '\"' }}\n {%- else %}\n {{- '\"' + key + '\": ' + val|tojson }}\n {%- endif %}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- endif %}\n {%- endfor %}\n {{- \"}}\" }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" }}\n {%- endif %}\n {%- endfor %}\n {{- \"[/AVAILABLE_TOOLS]\" }}\n {%- endif %}\n {%- if loop.last and system_message is defined %}\n {{- \"[INST]\" + system_message + \"\\n\\n\" + message[\"content\"] + \"[/INST]\" }}\n {%- else %}\n {{- \"[INST]\" + message[\"content\"] + \"[/INST]\" }}\n {%- endif %}\n {%- elif message[\"role\"] == \"tool_calls\" or message.tool_calls is defined %}\n {%- if message.tool_calls is defined %}\n {%- set tool_calls = message.tool_calls %}\n {%- else %}\n {%- set tool_calls = message.content %}\n {%- endif %}\n {{- \"[TOOL_CALLS][\" }}\n {%- for tool_call in tool_calls %}\n {%- set out = tool_call.function|tojson %}\n {{- out[:-1] }}\n {%- if not tool_call.id is defined or tool_call.id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- ', \"id\": \"' + tool_call.id + '\"}' }}\n {%- if not loop.last %}\n {{- \", \" }}\n {%- else %}\n {{- \"]\" + eos_token }}\n {%- endif %}\n {%- endfor %}\n {%- elif message[\"role\"] == \"assistant\" %}\n {{- message[\"content\"] + eos_token}}\n {%- elif message[\"role\"] == \"tool_results\" or message[\"role\"] == \"tool\" %}\n {%- if message.content is defined and message.content.content is defined %}\n {%- set content = message.content.content %}\n {%- else %}\n {%- set content = message.content %}\n {%- endif %}\n {{- '[TOOL_RESULTS]{\"content\": ' + content|string + \", \" }}\n {%- if not message.tool_call_id is defined or message.tool_call_id|length != 9 %}\n {{- raise_exception(\"Tool call IDs should be alphanumeric strings with length 9!\") }}\n {%- endif %}\n {{- '\"call_id\": \"' + message.tool_call_id + '\"}[/TOOL_RESULTS]' }}\n {%- else %}\n {{- raise_exception(\"Only user and assistant roles are supported, with the exception of an initial optional system message!\") }}\n {%- endif %}\n{%- endfor %}\n" + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": exampleChatMessages, + "bos_token": "", + "eos_token": "", + ]) + let target = + "[INST]Hello, how are you?[/INST]I'm doing great. How can I help you today?[INST]I'd like to show off how chat templating works![/INST]" + + XCTAssertEqual(result, target) + } + + func testMetaLlamaLlama3_18BInstruct() throws { + let chatTemplate = """ + {{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- set date_string = "26 Jul 2024" %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = "" %}\n{%- endif %}\n\n{#- System message + builtin tools #}\n{{- "<|start_header_id|>system<|end_header_id|>\\n\\n" }}\n{%- if builtin_tools is defined or tools is not none %}\n {{- "Environment: ipython\\n" }}\n{%- endif %}\n{%- if builtin_tools is defined %}\n {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\\n\\n"}}\n{%- endif %}\n{{- "Cutting Knowledge Date: December 2023\\n" }}\n{{- "Today Date: " + date_string + "\\n\\n" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}\n {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- "<|eot_id|>" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- "Given the following functions, please respond with a JSON for a function call " }}\n {{- "with its proper arguments that best answers the given prompt.\\n\\n" }}\n {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}\n {{- "Do not use variables.\\n\\n" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- "\\n\\n" }}\n {%- endfor %}\n {{- first_user_message + "<|eot_id|>"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception("This model only supports single tool-calls at once!") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {%- if builtin_tools is defined and tool_call.name in builtin_tools %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- "<|python_tag|>" + tool_call.name + ".call(" }}\n {%- for arg_name, arg_val in tool_call.arguments | items %}\n {{- arg_name + '="' + arg_val + '"' }}\n {%- if not loop.last %}\n {{- ", " }}\n {%- endif %}\n {%- endfor %}\n {{- ")" }}\n {%- else %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{"name": "' + tool_call.name + '", ' }}\n {{- '"parameters": ' }}\n {{- tool_call.arguments | tojson }}\n {{- "}" }}\n {%- endif %}\n {%- if builtin_tools is defined %}\n {#- This means we're in ipython mode #}\n {{- "<|eom_id|>" }}\n {%- else %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n {%- elif message.role == "tool" or message.role == "ipython" %}\n {{- "<|start_header_id|>ipython<|end_header_id|>\\n\\n" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- "<|eot_id|>" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n + """ + let template = try Template(chatTemplate) + let result = try template.render([ + "messages": [ + ["role": "system", "content": "You are a bot that responds to weather queries."], + ["role": "user", "content": "Hey, what's the temperature in Paris right now?"], + ], + "tools": [exampleToolJSONSchemas["get_current_temperature_v1"]!], + "bos_token": "<|begin_of_text|>", + "eos_token": "<|im_end|>", + "add_generation_prompt": true, + ]) + let target = """ + <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nEnvironment: ipython\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\nYou are a bot that responds to weather queries.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGiven the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.\n\nRespond in the format {"name": function name, "parameters": dictionary of argument name and its value}.Do not use variables.\n\n{\n "type": "function",\n "function": {\n "name": "get_current_temperature",\n "description": "Get the current temperature at a location.",\n "parameters": {\n "type": "object",\n "properties": {\n "location": {\n "type": "string",\n "description": "The location to get the temperature for, in the format \\"City, Country\\""\n }\n },\n "required": [\n "location"\n ]\n },\n "return": {\n "type": "number",\n "description": "The current temperature at the specified location in the specified units, as a float."\n }\n }\n}\n\nHey, what's the temperature in Paris right now?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n + """ + + if target != result { + print("::: testMetaLlamaLlama3_18BInstruct failed.") + print("::: target:") + print(target) + print("::: result:") + print(result) + } + XCTAssertEqual(result, target) + } +} + +extension Data { + var string: String? { + return String(data: self, encoding: .utf8) + } +}