Skip to content

Commit

Permalink
[Vertex AI] Add ImagenModel with generateImages functions (#14226)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Dec 7, 2024
1 parent 4606247 commit 57db276
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 6 deletions.
15 changes: 14 additions & 1 deletion FirebaseVertexAI/Sources/GenerationConfig.swift
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,17 @@ public struct GenerationConfig {
// MARK: - Codable Conformances

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
extension GenerationConfig: Encodable {}
extension GenerationConfig: Encodable {
  /// Keys used when encoding a `GenerationConfig`.
  ///
  /// Every key matches its property name except `responseMIMEType`, which is written under the
  /// key "responseMimeType".
  enum CodingKeys: String, CodingKey {
    case temperature, topP, topK
    case candidateCount, maxOutputTokens
    case presencePenalty, frequencyPenalty
    case stopSequences
    case responseMIMEType = "responseMimeType"
    case responseSchema
  }
}
1 change: 0 additions & 1 deletion FirebaseVertexAI/Sources/GenerativeAIService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ struct GenerativeAIService {
}

let encoder = JSONEncoder()
encoder.keyEncodingStrategy = .convertToSnakeCase
urlRequest.httpBody = try encoder.encode(request)
urlRequest.timeoutInterval = request.options.timeout

Expand Down
5 changes: 5 additions & 0 deletions FirebaseVertexAI/Sources/Types/Internal/InternalPart.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ struct FileData: Codable, Equatable, Sendable {
self.fileURI = fileURI
self.mimeType = mimeType
}

/// Keys used when encoding and decoding `FileData`.
enum CodingKeys: String, CodingKey {
// `fileURI` is serialized under the key "fileUri" rather than its Swift property name.
case fileURI = "fileUri"
case mimeType
}
}

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ extension ImageGenerationResponse: Decodable where ImageType: Decodable {
guard container.contains(.predictions) else {
images = []
raiFilteredReason = nil
// TODO: Log warning if no predictions.
// TODO(#14221): Log warning if no predictions.
return
}
var predictionsContainer = try container.nestedUnkeyedContainer(forKey: .predictions)
Expand Down
92 changes: 92 additions & 0 deletions FirebaseVertexAI/Sources/Types/Public/Imagen/ImagenModel.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import FirebaseAppCheckInterop
import FirebaseAuthInterop
import Foundation

@available(iOS 15.0, macOS 12.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
public final class ImagenModel {
  /// Backend resource name of the model, in the form "models/model-name".
  let modelResourceName: String

  /// Service through which model requests are sent to, and responses received from, the backend.
  let generativeAIService: GenerativeAIService

  /// Options applied to the requests this model sends to the backend.
  let requestOptions: RequestOptions

  /// Creates a model backed by a newly constructed `GenerativeAIService`.
  ///
  /// - Parameters:
  ///   - name: Resource name of the model; stored as `modelResourceName`.
  ///   - projectID: Forwarded to the `GenerativeAIService` initializer.
  ///   - apiKey: Forwarded to the `GenerativeAIService` initializer.
  ///   - requestOptions: Options attached to every request built by this model.
  ///   - appCheck: Optional App Check interop object, forwarded to the service.
  ///   - auth: Optional Auth interop object, forwarded to the service.
  ///   - urlSession: Session handed to the service; defaults to `URLSession.shared`.
  init(name: String,
       projectID: String,
       apiKey: String,
       requestOptions: RequestOptions,
       appCheck: AppCheckInterop?,
       auth: AuthInterop?,
       urlSession: URLSession = .shared) {
    let service = GenerativeAIService(
      projectID: projectID,
      apiKey: apiKey,
      appCheck: appCheck,
      auth: auth,
      urlSession: urlSession
    )
    self.modelResourceName = name
    self.generativeAIService = service
    self.requestOptions = requestOptions
  }

  /// Generates images for a text prompt and returns them as `ImagenInlineDataImage` values.
  ///
  /// The request is built with the default parameters from `imageGenerationParameters` and no
  /// storage URI.
  ///
  /// - Parameter prompt: Text describing the image to generate.
  /// - Returns: The decoded image generation response.
  /// - Throws: Any error thrown while loading the request.
  public func generateImages(prompt: String) async throws
    -> ImageGenerationResponse<ImagenInlineDataImage> {
    let inlineParameters = imageGenerationParameters(storageURI: nil)
    return try await generateImages(prompt: prompt, parameters: inlineParameters)
  }

  /// Generates images for a text prompt and returns them as `ImagenFileDataImage` values.
  ///
  /// - Parameters:
  ///   - prompt: Text describing the image to generate.
  ///   - storageURI: Passed through as the `storageURI` request parameter.
  ///     NOTE(review): presumably the storage location the backend writes images to; confirm
  ///     against the backend API documentation.
  /// - Returns: The decoded image generation response.
  /// - Throws: Any error thrown while loading the request.
  public func generateImages(prompt: String, storageURI: String) async throws
    -> ImageGenerationResponse<ImagenFileDataImage> {
    let storedParameters = imageGenerationParameters(storageURI: storageURI)
    return try await generateImages(prompt: prompt, parameters: storedParameters)
  }

  /// Builds an `ImageGenerationRequest` with a single instance for `prompt` and loads it through
  /// `generativeAIService`.
  ///
  /// The image type `T` of the response is inferred from the caller's expected return type.
  ///
  /// - Parameters:
  ///   - prompt: Text describing the image to generate.
  ///   - parameters: Generation parameters to send with the request.
  /// - Returns: The response produced by the service for the request.
  /// - Throws: Any error thrown by `GenerativeAIService.loadRequest(request:)`.
  func generateImages<T: Decodable>(prompt: String,
                                    parameters: ImageGenerationParameters) async throws
    -> ImageGenerationResponse<T> {
    let instance = ImageGenerationInstance(prompt: prompt)
    let request = ImageGenerationRequest<T>(
      model: modelResourceName,
      options: requestOptions,
      instances: [instance],
      parameters: parameters
    )

    return try await generativeAIService.loadRequest(request: request)
  }

  /// Returns the fixed set of generation parameters currently used for every request.
  ///
  /// Only `storageURI` varies: `sampleCount` is always 1, `includeResponsibleAIFilterReason` is
  /// always `true`, and every other parameter is left `nil`.
  ///
  /// - Parameter storageURI: Value for the `storageURI` parameter, or `nil` to omit it.
  func imageGenerationParameters(storageURI: String?) -> ImageGenerationParameters {
    // TODO(#14221): Add support for configuring these parameters.
    ImageGenerationParameters(
      sampleCount: 1,
      storageURI: storageURI,
      seed: nil,
      negativePrompt: nil,
      aspectRatio: nil,
      safetyFilterLevel: nil,
      personGeneration: nil,
      outputOptions: nil,
      addWatermark: nil,
      includeResponsibleAIFilterReason: true
    )
  }
}
12 changes: 12 additions & 0 deletions FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ public class VertexAI {
)
}

/// Creates an `ImagenModel` configured with this instance's project ID, API key, App Check and
/// Auth objects.
///
/// - Parameters:
///   - modelName: Name of the model; converted to a resource name with
///     `modelResourceName(modelName:)` before being handed to the model.
///   - requestOptions: Options for the model's requests; defaults to `RequestOptions()`.
/// - Returns: A new `ImagenModel`.
public func imagenModel(modelName: String, requestOptions: RequestOptions = RequestOptions())
  -> ImagenModel {
  let resourceName = modelResourceName(modelName: modelName)
  return ImagenModel(
    name: resourceName,
    projectID: projectID,
    apiKey: apiKey,
    requestOptions: requestOptions,
    appCheck: appCheck,
    auth: auth
  )
}

/// Class to enable VertexAI to register via the Objective-C based Firebase component system
/// to include VertexAI in the userAgent.
@objc(FIRVertexAIComponent) class FirebaseVertexAIComponent: NSObject {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ final class IntegrationTests: XCTestCase {

var vertex: VertexAI!
var model: GenerativeModel!
var imagenModel: ImagenModel!
var storage: Storage!
var userID1 = ""

Expand All @@ -60,6 +61,9 @@ final class IntegrationTests: XCTestCase {
toolConfig: .init(functionCallingConfig: .none()),
systemInstruction: systemInstruction
)
imagenModel = vertex.imagenModel(
modelName: "imagen-3.0-fast-generate-001"
)

storage = Storage.storage()
}
Expand Down Expand Up @@ -235,6 +239,30 @@ final class IntegrationTests: XCTestCase {
XCTAssertTrue(String(describing: error).contains("Firebase App Check token is invalid"))
}
}

// MARK: - Imagen

/// Generates one inline image from a lion prompt, then asks the generative model to name the
/// animal in that image and expects the answer "Lion".
func testGenerateImage_inlineData() async throws {
  let prompt = """
  A realistic photo of a male lion, mane thick and dark, standing proudly on a rocky outcrop
  overlooking a vast African savanna at sunset. Golden hour light, long shadows, sharp focus on
  the lion, shallow depth of field, detailed fur texture, DSLR, 85mm lens.
  """

  let generated = try await imagenModel.generateImages(prompt: prompt)

  // Exactly one unfiltered image is expected.
  XCTAssertNil(generated.raiFilteredReason)
  XCTAssertEqual(generated.images.count, 1)
  let lionImage = try XCTUnwrap(generated.images.first)

  // Feed the generated image back to the text model to check what it depicts.
  let imagePart = InlineDataPart(data: lionImage.data, mimeType: "image/png")
  let question = "What is the name of this animal? Answer with the animal name only."
  let answerResponse = try await model.generateContent(imagePart, question)

  let answer = try XCTUnwrap(answerResponse.text)
    .trimmingCharacters(in: .whitespacesAndNewlines)
  XCTAssertEqual(answer, "Lion")
}
}

extension StorageReference {
Expand Down
4 changes: 2 additions & 2 deletions FirebaseVertexAI/Tests/Unit/GenerationConfigTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ final class GenerationConfigTests: XCTestCase {
"frequencyPenalty" : \(frequencyPenalty),
"maxOutputTokens" : \(maxOutputTokens),
"presencePenalty" : \(presencePenalty),
"responseMIMEType" : "\(responseMIMEType)",
"responseMimeType" : "\(responseMIMEType)",
"responseSchema" : {
"items" : {
"nullable" : false,
Expand Down Expand Up @@ -109,7 +109,7 @@ final class GenerationConfigTests: XCTestCase {
let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8))
XCTAssertEqual(json, """
{
"responseMIMEType" : "\(mimeType)",
"responseMimeType" : "\(mimeType)",
"responseSchema" : {
"nullable" : false,
"properties" : {
Expand Down
2 changes: 1 addition & 1 deletion FirebaseVertexAI/Tests/Unit/PartTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ final class PartTests: XCTestCase {
XCTAssertEqual(json, """
{
"fileData" : {
"fileURI" : "\(fileURI)",
"fileUri" : "\(fileURI)",
"mimeType" : "\(mimeType)"
}
}
Expand Down

0 comments on commit 57db276

Please sign in to comment.