Skip to content

Commit

Permalink
Send a request to config.json when downloading model from HF reposito…
Browse files Browse the repository at this point in the history
…ry (#31)

## Description
Currently, the HuggingFace download count is not incremented when
downloading models. HF counts the model downloads by counting the
requests sent to the config.json file in the repository root (a HEAD
request is enough). This PR sends a HEAD request to the config json,
when the model source is our organization.
source: https://huggingface.co/docs/hub/models-download-stats

### Type of change
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)
- [ ] Documentation update (improves or adds clarity to existing
documentation)

### Tested on
- [x] iOS
- [x] Android

### Testing instructions
<!-- Provide step-by-step instructions on how to test your changes.
Include setup details if necessary. -->

### Screenshots
<!-- Add screenshots here, if applicable -->

### Related issues
<!-- Link related issues here using #issue-number -->

### Checklist
- [x] I have performed a self-review of my code
- [x] I have commented my code, particularly in hard-to-understand areas
- [ ] I have updated the documentation accordingly
- [x] My changes generate no new warnings

### Additional notes
<!-- Include any additional information, assumptions, or context that
reviewers might need to understand this PR. -->
  • Loading branch information
chmjkb authored Nov 22, 2024
1 parent 0ae6023 commit fe00d6b
Show file tree
Hide file tree
Showing 2 changed files with 184 additions and 122 deletions.
283 changes: 161 additions & 122 deletions android/src/main/java/com/rnexecutorch/Fetcher.kt
Original file line number Diff line number Diff line change
Expand Up @@ -5,104 +5,135 @@ import okhttp3.Call
import okhttp3.Callback
import okhttp3.OkHttpClient
import okhttp3.Request
import okhttp3.RequestBody
import okhttp3.Response
import okio.IOException
import java.io.File
import java.io.FileOutputStream
import java.net.URL

enum class ResourceType {
TOKENIZER,
MODEL
TOKENIZER,
MODEL,
}

class Fetcher {
companion object {
private fun saveResponseToFile(
response: Response,
directory: File,
fileName: String
): File {
val file = File(directory.path, fileName)
file.outputStream().use { outputStream ->
response.body?.byteStream()?.copyTo(outputStream)
}
return file
companion object {
private fun saveResponseToFile(
response: Response,
directory: File,
fileName: String,
): File {
val file = File(directory.path, fileName)
file.outputStream().use { outputStream ->
response.body?.byteStream()?.copyTo(outputStream)
}
return file
}

private fun hasValidExtension(fileName: String, resourceType: ResourceType): Boolean {
return when (resourceType) {
ResourceType.TOKENIZER -> {
fileName.endsWith(".bin")
}

private fun hasValidExtension(fileName: String, resourceType: ResourceType): Boolean {
return when (resourceType) {
ResourceType.TOKENIZER -> {
fileName.endsWith(".bin")
}
ResourceType.MODEL -> {
fileName.endsWith(".pte")
}
}
}

ResourceType.MODEL -> {
fileName.endsWith(".pte")
}
}
private fun extractFileName(url: URL): String {
if (url.path == "/assets/") {
val pathSegments = url.toString().split('/')
return pathSegments[pathSegments.size - 1].split("?")[0]
} else if (url.protocol == "file") {
val localPath = url.toString().split("://")[1]
val file = File(localPath)
if (file.exists()) {
return localPath
}

private fun extractFileName(url: URL): String {
if (url.path == "/assets/") {
val pathSegments = url.toString().split('/')
return pathSegments[pathSegments.size - 1].split("?")[0]
} else if (url.protocol == "file") {
val localPath = url.toString().split("://")[1]
val file = File(localPath)
if (file.exists()) {
return localPath
throw Exception("file_not_found")
} else {
return url.path.substringAfterLast('/')
}
}

private fun fetchModel(
file: File,
validFile: File,
client: OkHttpClient,
url: URL,
onComplete: (String?, Exception?) -> Unit,
listener: ProgressResponseBody.ProgressListener? = null,
) {
val request = Request.Builder().url(url).build()
client.newCall(request).enqueue(object : Callback {
override fun onFailure(call: Call, e: IOException) {
onComplete(null, e)
}

override fun onResponse(call: Call, response: Response) {
if (!response.isSuccessful) {
onComplete(null, Exception("download_error"))
}

response.body?.let { body ->
val progressBody = listener?.let { ProgressResponseBody(body, it) }
val inputStream = progressBody?.source()?.inputStream()
inputStream?.use { input ->
FileOutputStream(file).use { output ->
val buffer = ByteArray(2048)
var bytesRead: Int
while (input.read(buffer).also { bytesRead = it } != -1) {
output.write(buffer, 0, bytesRead)
}
}
}

throw Exception("file_not_found")
if (file.renameTo(validFile)) {
onComplete(validFile.absolutePath, null)
} else {
return url.path.substringAfterLast('/')
onComplete(null, Exception("Failed to move the file to the valid location"))
}
}
}
})
}

private fun fetchModel(file: File, validFile: File, client: OkHttpClient, url: URL, onComplete: (String?, Exception?) -> Unit,
listener: ProgressResponseBody.ProgressListener? = null){
val request = Request.Builder().url(url).build()
client.newCall(request).enqueue(object : Callback {
override fun onFailure(call: Call, e: IOException) {
onComplete(null, e)
}
private fun isUrlPointingToHfRepo(url: URL): Boolean {
val expectedHost = "huggingface.co"
val expectedPathPrefix = "/software-mansion/"
if (url.host != expectedHost) {
return false
}
return url.path.startsWith(expectedPathPrefix)
}

override fun onResponse(call: Call, response: Response) {
if (!response.isSuccessful) {
onComplete(null, Exception("download_error"))
}

response.body?.let { body ->
val progressBody = listener?.let { ProgressResponseBody(body, it) }
val inputStream = progressBody?.source()?.inputStream()
inputStream?.use { input ->
FileOutputStream(file).use { output ->
val buffer = ByteArray(2048)
var bytesRead: Int
while (input.read(buffer).also { bytesRead = it } != -1) {
output.write(buffer, 0, bytesRead)
}
}
}

if (file.renameTo(validFile)) {
onComplete(validFile.absolutePath, null)
} else {
onComplete(null, Exception("Failed to move the file to the valid location"))
}
}
}
})
}
private fun resolveConfigUrlFromModelUrl(modelUrl: URL): URL {
// Create a new URL using the base URL and append the desired path
val baseUrl = modelUrl.protocol + "://" + modelUrl.host + modelUrl.path.substringBefore("resolve/")
return URL(baseUrl + "resolve/main/config.json")
}

fun downloadResource(
context: Context,
client: OkHttpClient,
url: URL,
resourceType: ResourceType,
onComplete: (String?, Exception?) -> Unit,
listener: ProgressResponseBody.ProgressListener? = null
) {
private fun sendRequestToUrl(url: URL, method: String, body: RequestBody?, client: OkHttpClient): Response {
val request = Request.Builder()
.url(url)
.method(method, body)
.build()
val response = client.newCall(request).execute()
return response
}

fun downloadResource(
context: Context,
client: OkHttpClient,
url: URL,
resourceType: ResourceType,
onComplete: (String?, Exception?) -> Unit,
listener: ProgressResponseBody.ProgressListener? = null,
) {
/*
Fetching model and tokenizer file
1. Extract file name from provided URL
Expand All @@ -115,57 +146,65 @@ class Fetcher {
6. If the file does not exist, and is a tokenizer, fetch the file
7. If the file is a model, fetch the file with ProgressResponseBody
*/
val fileName: String

try {
fileName = extractFileName(url)
} catch (e: Exception) {
onComplete(null, e)
return
}

if(fileName.contains("/")){
onComplete(fileName, null)
return
}

if (!hasValidExtension(fileName, resourceType)) {
onComplete(null, Exception("invalid_resource_extension"))
return
}

var tempFile = File(context.filesDir, fileName)
if(tempFile.exists()){
tempFile.delete()
}
val fileName: String

try {
fileName = extractFileName(url)
} catch (e: Exception) {
onComplete(null, e)
return
}

if (fileName.contains("/")) {
onComplete(fileName, null)
return
}

if (!hasValidExtension(fileName, resourceType)) {
onComplete(null, Exception("invalid_resource_extension"))
return
}

var tempFile = File(context.filesDir, fileName)
if (tempFile.exists()) {
tempFile.delete()
}

val modelsDirectory = File(context.filesDir, "models").apply {
if (!exists()) {
mkdirs()
}
}

val modelsDirectory = File(context.filesDir, "models").apply {
if (!exists()) {
mkdirs()
}
}
var validFile = File(modelsDirectory, fileName)
if (validFile.exists()) {
onComplete(validFile.absolutePath, null)
return
}

var validFile = File(modelsDirectory, fileName)
if (validFile.exists()) {
onComplete(validFile.absolutePath, null)
return
}
if (resourceType == ResourceType.TOKENIZER) {
val request = Request.Builder().url(url).build()
val response = client.newCall(request).execute()

if (resourceType == ResourceType.TOKENIZER) {
val request = Request.Builder().url(url).build()
val response = client.newCall(request).execute()
if (!response.isSuccessful) {
onComplete(null, Exception("download_error"))
return
}

if (!response.isSuccessful) {
onComplete(null, Exception("download_error"))
return
}
validFile = saveResponseToFile(response, modelsDirectory, fileName)
onComplete(validFile.absolutePath, null)
return
}

validFile = saveResponseToFile(response, modelsDirectory, fileName)
onComplete(validFile.absolutePath, null)
return
}
// If the url is a Software Mansion HuggingFace repo, we want to send a HEAD
// request to the config.json file, this increments HF download counter
// https://huggingface.co/docs/hub/models-download-stats
if (isUrlPointingToHfRepo(url)) {
val configUrl = resolveConfigUrlFromModelUrl(url)
sendRequestToUrl(configUrl, "HEAD", null, client)
}

fetchModel(tempFile, validFile, client, url, onComplete, listener)
}
fetchModel(tempFile, validFile, client, url, onComplete, listener)
}
}
}
}
23 changes: 23 additions & 0 deletions ios/utils/LargeFileFetcher.mm
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ - (void)URLSession:(NSURLSession *)session downloadTask:(NSURLSessionDownloadTas
}
}

- (void)sendHeadRequestToURL:(NSURL *)url {
NSMutableURLRequest *request = [NSMutableURLRequest requestWithURL:url];
[request setHTTPMethod:@"HEAD"];
NSURLSessionDataTask *dataTask = [_session dataTaskWithRequest:request];
[dataTask resume];
}

- (void)startDownloadingFileFromURL:(NSURL *)url {
//Check if file is a valid url, if not check if it's path to local file
if (![Fetcher isValidURL:url]) {
Expand All @@ -77,6 +84,22 @@ - (void)startDownloadingFileFromURL:(NSURL *)url {
[self executeCompletionWithSuccess:filePath];
return;
}

// If the url is a Software Mansion HuggingFace repo, we want to send a HEAD
// request to the config.json file, this increments HF download counter
// https://huggingface.co/docs/hub/models-download-stats
NSString *huggingFaceOrgNSString = @"https://huggingface.co/software-mansion/";
NSString *modelURLNSString = [url absoluteString];

if ([modelURLNSString hasPrefix:huggingFaceOrgNSString]) {
NSRange resolveRange = [modelURLNSString rangeOfString:@"resolve"];
if (resolveRange.location != NSNotFound) {
NSString *configURLNSString = [modelURLNSString substringToIndex:resolveRange.location + resolveRange.length];
configURLNSString = [configURLNSString stringByAppendingString:@"/main/config.json"];
NSURL *configNSURL = [NSURL URLWithString:configURLNSString];
[self sendHeadRequestToURL:configNSURL];
}
}

//Cancel all running background download tasks and start new one
_destination = filePath;
Expand Down

0 comments on commit fe00d6b

Please sign in to comment.