mirror of
https://github.com/spotiflacapp/SpotiFLAC-Mobile.git
synced 2026-06-01 03:15:17 +07:00
feat: propagate download cancel to extension HTTP requests and fix SAF filename extension mismatch
- Bind cancel context to all extension HTTP calls (fetch, httpGet, httpPost, httpRequest, fileDownload, authExchangeCodeWithPKCE) so in-flight requests are aborted when user cancels a download - Make initDownloadCancel idempotent: return existing context if entry already exists and preserve pre-cancelled state - Force SAF output filename to match actual file extension when extension returns a different format than requested (e.g. FLAC requested but M4A produced) - Map ALAC/AAC quality to .m4a instead of falling through to default .flac
This commit is contained in:
parent
7405855e01
commit
dbba4d6630
10 changed files with 141 additions and 3 deletions
|
|
@ -308,6 +308,21 @@ class MainActivity: FlutterFragmentActivity() {
|
|||
}
|
||||
}
|
||||
|
||||
private fun forceFilenameExt(name: String, outputExt: String): String {
|
||||
val normalizedExt = normalizeExt(outputExt)
|
||||
if (normalizedExt.isBlank()) return sanitizeFilename(name)
|
||||
|
||||
val safeName = sanitizeFilename(name)
|
||||
val lower = safeName.lowercase(Locale.ROOT)
|
||||
val knownExts = listOf(".flac", ".m4a", ".mp3", ".opus", ".lrc")
|
||||
for (knownExt in knownExts) {
|
||||
if (lower.endsWith(knownExt)) {
|
||||
return safeName.dropLast(knownExt.length) + normalizedExt
|
||||
}
|
||||
}
|
||||
return safeName + normalizedExt
|
||||
}
|
||||
|
||||
private fun sanitizeFilename(name: String): String {
|
||||
var sanitized = name
|
||||
.replace("/", " ")
|
||||
|
|
@ -617,12 +632,12 @@ class MainActivity: FlutterFragmentActivity() {
|
|||
|
||||
private fun buildSafFileName(req: JSONObject, outputExt: String): String {
|
||||
val provided = req.optString("saf_file_name", "")
|
||||
if (provided.isNotBlank()) return sanitizeFilename(provided)
|
||||
if (provided.isNotBlank()) return forceFilenameExt(provided, outputExt)
|
||||
|
||||
val trackName = req.optString("track_name", "track")
|
||||
val artistName = req.optString("artist_name", "")
|
||||
val baseName = if (artistName.isNotBlank()) "$artistName - $trackName" else trackName
|
||||
return sanitizeFilename(baseName) + outputExt
|
||||
return forceFilenameExt(baseName, outputExt)
|
||||
}
|
||||
|
||||
private fun errorJson(message: String): String {
|
||||
|
|
@ -937,7 +952,7 @@ class MainActivity: FlutterFragmentActivity() {
|
|||
?: return errorJson("Failed to access SAF directory")
|
||||
|
||||
val existingFile = targetDir.findFile(fileName)
|
||||
val document = existingFile ?: targetDir.createFile(mimeType, fileName)
|
||||
var document = existingFile ?: targetDir.createFile(mimeType, fileName)
|
||||
?: return errorJson("Failed to create SAF file")
|
||||
|
||||
val pfd = contentResolver.openFileDescriptor(document.uri, "rw")
|
||||
|
|
@ -965,6 +980,18 @@ class MainActivity: FlutterFragmentActivity() {
|
|||
if (!srcFile.exists() || srcFile.length() <= 0) {
|
||||
throw IllegalStateException("extension output missing or empty: $goFilePath")
|
||||
}
|
||||
val actualExt = normalizeExt(srcFile.extension)
|
||||
if (actualExt.isNotBlank() && actualExt != outputExt) {
|
||||
val actualFileName = buildSafFileName(req, actualExt)
|
||||
val actualMimeType = mimeTypeForExt(actualExt)
|
||||
val replacement = targetDir.findFile(actualFileName)
|
||||
?: targetDir.createFile(actualMimeType, actualFileName)
|
||||
?: throw IllegalStateException("failed to create SAF output with actual extension")
|
||||
if (replacement.uri != document.uri) {
|
||||
document.delete()
|
||||
document = replacement
|
||||
}
|
||||
}
|
||||
contentResolver.openOutputStream(document.uri, "wt")?.use { output ->
|
||||
srcFile.inputStream().use { input ->
|
||||
input.copyTo(output)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
var ErrDownloadCancelled = errors.New("download cancelled")
|
||||
|
||||
type cancelEntry struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
canceled bool
|
||||
}
|
||||
|
|
@ -27,8 +28,21 @@ func initDownloadCancel(itemID string) context.Context {
|
|||
cancelMu.Lock()
|
||||
defer cancelMu.Unlock()
|
||||
|
||||
if entry, ok := cancelMap[itemID]; ok {
|
||||
if entry.ctx == nil {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
entry.ctx = ctx
|
||||
entry.cancel = cancel
|
||||
if entry.canceled && entry.cancel != nil {
|
||||
entry.cancel()
|
||||
}
|
||||
}
|
||||
return entry.ctx
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancelMap[itemID] = &cancelEntry{
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
canceled: false,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -615,6 +615,10 @@ func (p *extensionProviderWrapper) Download(trackID, quality, outputPath, itemID
|
|||
p.extension.runtime.setActiveDownloadItemID(itemID)
|
||||
defer p.extension.runtime.clearActiveDownloadItemID()
|
||||
}
|
||||
if itemID != "" {
|
||||
initDownloadCancel(itemID)
|
||||
defer clearDownloadCancel(itemID)
|
||||
}
|
||||
|
||||
p.vm.Set("__onProgress", func(call goja.FunctionCall) goja.Value {
|
||||
if len(call.Arguments) > 0 {
|
||||
|
|
|
|||
|
|
@ -160,6 +160,19 @@ func (r *extensionRuntime) getActiveDownloadItemID() string {
|
|||
return r.activeDownloadItemID
|
||||
}
|
||||
|
||||
func (r *extensionRuntime) bindDownloadCancelContext(req *http.Request) *http.Request {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
itemID := r.getActiveDownloadItemID()
|
||||
if itemID == "" {
|
||||
return req
|
||||
}
|
||||
|
||||
return req.WithContext(initDownloadCancel(itemID))
|
||||
}
|
||||
|
||||
func newExtensionHTTPClient(ext *loadedExtension, jar http.CookieJar, timeout time.Duration) *http.Client {
|
||||
// Extension sandbox enforces HTTPS-only domains. Do not apply global
|
||||
// allow_http scheme downgrade here, because some extension APIs (e.g.
|
||||
|
|
|
|||
|
|
@ -458,6 +458,7 @@ func (r *extensionRuntime) authExchangeCodeWithPKCE(call goja.FunctionCall) goja
|
|||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
req = r.bindDownloadCancelContext(req)
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("User-Agent", "SpotiFLAC-Extension/1.0")
|
||||
|
|
|
|||
|
|
@ -166,6 +166,7 @@ func (r *extensionRuntime) fileDownload(call goja.FunctionCall) goja.Value {
|
|||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
req = r.bindDownloadCancelContext(req)
|
||||
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
|
|
|
|||
|
|
@ -81,6 +81,7 @@ func (r *extensionRuntime) httpGet(call goja.FunctionCall) goja.Value {
|
|||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
req = r.bindDownloadCancelContext(req)
|
||||
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
|
|
@ -175,6 +176,7 @@ func (r *extensionRuntime) httpPost(call goja.FunctionCall) goja.Value {
|
|||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
req = r.bindDownloadCancelContext(req)
|
||||
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
|
|
@ -284,6 +286,7 @@ func (r *extensionRuntime) httpRequest(call goja.FunctionCall) goja.Value {
|
|||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
req = r.bindDownloadCancelContext(req)
|
||||
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
|
|
@ -410,6 +413,7 @@ func (r *extensionRuntime) httpMethodShortcut(method string, call goja.FunctionC
|
|||
"error": err.Error(),
|
||||
})
|
||||
}
|
||||
req = r.bindDownloadCancelContext(req)
|
||||
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
|
|
|
|||
|
|
@ -69,6 +69,7 @@ func (r *extensionRuntime) fetchPolyfill(call goja.FunctionCall) goja.Value {
|
|||
if err != nil {
|
||||
return r.createFetchError(err.Error())
|
||||
}
|
||||
req = r.bindDownloadCancelContext(req)
|
||||
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
package gobackend
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/dop251/goja"
|
||||
)
|
||||
|
|
@ -290,6 +292,76 @@ func TestExtensionRuntime_UtilityFunctions(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestExtensionRuntime_BindDownloadCancelContext(t *testing.T) {
|
||||
ext := &loadedExtension{
|
||||
ID: "test-ext",
|
||||
Manifest: &ExtensionManifest{
|
||||
Name: "test-ext",
|
||||
},
|
||||
DataDir: t.TempDir(),
|
||||
}
|
||||
|
||||
runtime := newExtensionRuntime(ext)
|
||||
runtime.setActiveDownloadItemID("test-item")
|
||||
t.Cleanup(func() {
|
||||
clearDownloadCancel("test-item")
|
||||
runtime.clearActiveDownloadItemID()
|
||||
})
|
||||
|
||||
req, err := http.NewRequest("GET", "https://api.example.com/test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest failed: %v", err)
|
||||
}
|
||||
|
||||
req = runtime.bindDownloadCancelContext(req)
|
||||
cancelDownload("test-item")
|
||||
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatal("Expected bound request context to be cancelled")
|
||||
}
|
||||
|
||||
if req.Context().Err() == nil {
|
||||
t.Fatal("Expected request context error after cancellation")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtensionRuntime_BindDownloadCancelContextPreservesPreCancelledState(t *testing.T) {
|
||||
ext := &loadedExtension{
|
||||
ID: "test-ext",
|
||||
Manifest: &ExtensionManifest{
|
||||
Name: "test-ext",
|
||||
},
|
||||
DataDir: t.TempDir(),
|
||||
}
|
||||
|
||||
runtime := newExtensionRuntime(ext)
|
||||
runtime.setActiveDownloadItemID("test-item")
|
||||
cancelDownload("test-item")
|
||||
t.Cleanup(func() {
|
||||
clearDownloadCancel("test-item")
|
||||
runtime.clearActiveDownloadItemID()
|
||||
})
|
||||
|
||||
req, err := http.NewRequest("GET", "https://api.example.com/test", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest failed: %v", err)
|
||||
}
|
||||
|
||||
req = runtime.bindDownloadCancelContext(req)
|
||||
|
||||
select {
|
||||
case <-req.Context().Done():
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Fatal("Expected pre-cancelled request context to stay cancelled")
|
||||
}
|
||||
|
||||
if req.Context().Err() == nil {
|
||||
t.Fatal("Expected request context error for pre-cancelled item")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtensionRuntime_SSRFProtection(t *testing.T) {
|
||||
// Create extension with limited network permissions
|
||||
ext := &loadedExtension{
|
||||
|
|
|
|||
|
|
@ -2378,6 +2378,7 @@ class DownloadQueueNotifier extends Notifier<DownloadQueueState> {
|
|||
return '.m4a';
|
||||
}
|
||||
final q = quality.toLowerCase();
|
||||
if (q == 'alac' || q.startsWith('aac')) return '.m4a';
|
||||
if (q.startsWith('opus')) return '.opus';
|
||||
if (q.startsWith('mp3')) return '.mp3';
|
||||
return '.flac';
|
||||
|
|
|
|||
Loading…
Reference in a new issue