feat: propagate download cancel to extension HTTP requests and fix SAF filename extension mismatch

- Bind cancel context to all extension HTTP calls (fetch, httpGet, httpPost,
  httpRequest, fileDownload, authExchangeCodeWithPKCE) so in-flight requests
  are aborted when user cancels a download
- Make initDownloadCancel idempotent: return existing context if entry already
  exists and preserve pre-cancelled state
- Force SAF output filename to match actual file extension when extension
  returns a different format than requested (e.g. FLAC requested but M4A produced)
- Map ALAC/AAC quality to .m4a instead of falling through to default .flac
This commit is contained in:
zarzet 2026-04-13 22:20:17 +07:00
parent 7405855e01
commit dbba4d6630
10 changed files with 141 additions and 3 deletions

View file

@ -308,6 +308,21 @@ class MainActivity: FlutterFragmentActivity() {
}
}
private fun forceFilenameExt(name: String, outputExt: String): String {
val normalizedExt = normalizeExt(outputExt)
if (normalizedExt.isBlank()) return sanitizeFilename(name)
val safeName = sanitizeFilename(name)
val lower = safeName.lowercase(Locale.ROOT)
val knownExts = listOf(".flac", ".m4a", ".mp3", ".opus", ".lrc")
for (knownExt in knownExts) {
if (lower.endsWith(knownExt)) {
return safeName.dropLast(knownExt.length) + normalizedExt
}
}
return safeName + normalizedExt
}
private fun sanitizeFilename(name: String): String {
var sanitized = name
.replace("/", " ")
@ -617,12 +632,12 @@ class MainActivity: FlutterFragmentActivity() {
private fun buildSafFileName(req: JSONObject, outputExt: String): String {
val provided = req.optString("saf_file_name", "")
if (provided.isNotBlank()) return sanitizeFilename(provided)
if (provided.isNotBlank()) return forceFilenameExt(provided, outputExt)
val trackName = req.optString("track_name", "track")
val artistName = req.optString("artist_name", "")
val baseName = if (artistName.isNotBlank()) "$artistName - $trackName" else trackName
return sanitizeFilename(baseName) + outputExt
return forceFilenameExt(baseName, outputExt)
}
private fun errorJson(message: String): String {
@ -937,7 +952,7 @@ class MainActivity: FlutterFragmentActivity() {
?: return errorJson("Failed to access SAF directory")
val existingFile = targetDir.findFile(fileName)
val document = existingFile ?: targetDir.createFile(mimeType, fileName)
var document = existingFile ?: targetDir.createFile(mimeType, fileName)
?: return errorJson("Failed to create SAF file")
val pfd = contentResolver.openFileDescriptor(document.uri, "rw")
@ -965,6 +980,18 @@ class MainActivity: FlutterFragmentActivity() {
if (!srcFile.exists() || srcFile.length() <= 0) {
throw IllegalStateException("extension output missing or empty: $goFilePath")
}
val actualExt = normalizeExt(srcFile.extension)
if (actualExt.isNotBlank() && actualExt != outputExt) {
val actualFileName = buildSafFileName(req, actualExt)
val actualMimeType = mimeTypeForExt(actualExt)
val replacement = targetDir.findFile(actualFileName)
?: targetDir.createFile(actualMimeType, actualFileName)
?: throw IllegalStateException("failed to create SAF output with actual extension")
if (replacement.uri != document.uri) {
document.delete()
document = replacement
}
}
contentResolver.openOutputStream(document.uri, "wt")?.use { output ->
srcFile.inputStream().use { input ->
input.copyTo(output)

View file

@ -10,6 +10,7 @@ import (
var ErrDownloadCancelled = errors.New("download cancelled")
type cancelEntry struct {
ctx context.Context
cancel context.CancelFunc
canceled bool
}
@ -27,8 +28,21 @@ func initDownloadCancel(itemID string) context.Context {
cancelMu.Lock()
defer cancelMu.Unlock()
if entry, ok := cancelMap[itemID]; ok {
if entry.ctx == nil {
ctx, cancel := context.WithCancel(context.Background())
entry.ctx = ctx
entry.cancel = cancel
if entry.canceled && entry.cancel != nil {
entry.cancel()
}
}
return entry.ctx
}
ctx, cancel := context.WithCancel(context.Background())
cancelMap[itemID] = &cancelEntry{
ctx: ctx,
cancel: cancel,
canceled: false,
}

View file

@ -615,6 +615,10 @@ func (p *extensionProviderWrapper) Download(trackID, quality, outputPath, itemID
p.extension.runtime.setActiveDownloadItemID(itemID)
defer p.extension.runtime.clearActiveDownloadItemID()
}
if itemID != "" {
initDownloadCancel(itemID)
defer clearDownloadCancel(itemID)
}
p.vm.Set("__onProgress", func(call goja.FunctionCall) goja.Value {
if len(call.Arguments) > 0 {

View file

@ -160,6 +160,19 @@ func (r *extensionRuntime) getActiveDownloadItemID() string {
return r.activeDownloadItemID
}
func (r *extensionRuntime) bindDownloadCancelContext(req *http.Request) *http.Request {
if req == nil {
return nil
}
itemID := r.getActiveDownloadItemID()
if itemID == "" {
return req
}
return req.WithContext(initDownloadCancel(itemID))
}
func newExtensionHTTPClient(ext *loadedExtension, jar http.CookieJar, timeout time.Duration) *http.Client {
// Extension sandbox enforces HTTPS-only domains. Do not apply global
// allow_http scheme downgrade here, because some extension APIs (e.g.

View file

@ -458,6 +458,7 @@ func (r *extensionRuntime) authExchangeCodeWithPKCE(call goja.FunctionCall) goja
"error": err.Error(),
})
}
req = r.bindDownloadCancelContext(req)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("User-Agent", "SpotiFLAC-Extension/1.0")

View file

@ -166,6 +166,7 @@ func (r *extensionRuntime) fileDownload(call goja.FunctionCall) goja.Value {
"error": err.Error(),
})
}
req = r.bindDownloadCancelContext(req)
for k, v := range headers {
req.Header.Set(k, v)

View file

@ -81,6 +81,7 @@ func (r *extensionRuntime) httpGet(call goja.FunctionCall) goja.Value {
"error": err.Error(),
})
}
req = r.bindDownloadCancelContext(req)
for k, v := range headers {
req.Header.Set(k, v)
@ -175,6 +176,7 @@ func (r *extensionRuntime) httpPost(call goja.FunctionCall) goja.Value {
"error": err.Error(),
})
}
req = r.bindDownloadCancelContext(req)
for k, v := range headers {
req.Header.Set(k, v)
@ -284,6 +286,7 @@ func (r *extensionRuntime) httpRequest(call goja.FunctionCall) goja.Value {
"error": err.Error(),
})
}
req = r.bindDownloadCancelContext(req)
for k, v := range headers {
req.Header.Set(k, v)
@ -410,6 +413,7 @@ func (r *extensionRuntime) httpMethodShortcut(method string, call goja.FunctionC
"error": err.Error(),
})
}
req = r.bindDownloadCancelContext(req)
for k, v := range headers {
req.Header.Set(k, v)

View file

@ -69,6 +69,7 @@ func (r *extensionRuntime) fetchPolyfill(call goja.FunctionCall) goja.Value {
if err != nil {
return r.createFetchError(err.Error())
}
req = r.bindDownloadCancelContext(req)
for k, v := range headers {
req.Header.Set(k, v)

View file

@ -1,8 +1,10 @@
package gobackend
import (
"net/http"
"path/filepath"
"testing"
"time"
"github.com/dop251/goja"
)
@ -290,6 +292,76 @@ func TestExtensionRuntime_UtilityFunctions(t *testing.T) {
}
}
func TestExtensionRuntime_BindDownloadCancelContext(t *testing.T) {
ext := &loadedExtension{
ID: "test-ext",
Manifest: &ExtensionManifest{
Name: "test-ext",
},
DataDir: t.TempDir(),
}
runtime := newExtensionRuntime(ext)
runtime.setActiveDownloadItemID("test-item")
t.Cleanup(func() {
clearDownloadCancel("test-item")
runtime.clearActiveDownloadItemID()
})
req, err := http.NewRequest("GET", "https://api.example.com/test", nil)
if err != nil {
t.Fatalf("NewRequest failed: %v", err)
}
req = runtime.bindDownloadCancelContext(req)
cancelDownload("test-item")
select {
case <-req.Context().Done():
case <-time.After(500 * time.Millisecond):
t.Fatal("Expected bound request context to be cancelled")
}
if req.Context().Err() == nil {
t.Fatal("Expected request context error after cancellation")
}
}
func TestExtensionRuntime_BindDownloadCancelContextPreservesPreCancelledState(t *testing.T) {
ext := &loadedExtension{
ID: "test-ext",
Manifest: &ExtensionManifest{
Name: "test-ext",
},
DataDir: t.TempDir(),
}
runtime := newExtensionRuntime(ext)
runtime.setActiveDownloadItemID("test-item")
cancelDownload("test-item")
t.Cleanup(func() {
clearDownloadCancel("test-item")
runtime.clearActiveDownloadItemID()
})
req, err := http.NewRequest("GET", "https://api.example.com/test", nil)
if err != nil {
t.Fatalf("NewRequest failed: %v", err)
}
req = runtime.bindDownloadCancelContext(req)
select {
case <-req.Context().Done():
case <-time.After(500 * time.Millisecond):
t.Fatal("Expected pre-cancelled request context to stay cancelled")
}
if req.Context().Err() == nil {
t.Fatal("Expected request context error for pre-cancelled item")
}
}
func TestExtensionRuntime_SSRFProtection(t *testing.T) {
// Create extension with limited network permissions
ext := &loadedExtension{

View file

@ -2378,6 +2378,7 @@ class DownloadQueueNotifier extends Notifier<DownloadQueueState> {
return '.m4a';
}
final q = quality.toLowerCase();
if (q == 'alac' || q.startsWith('aac')) return '.m4a';
if (q.startsWith('opus')) return '.opus';
if (q.startsWith('mp3')) return '.mp3';
return '.flac';