Files
Mubert-Text-to-Music/Mubert_Text_to_Music.ipynb
Ilya Belikov 6b3dd76741 prune outputs
2022-10-18 08:49:49 +03:00

249 lines
10 KiB
Plaintext

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [],
"toc_visible": true,
"authorship_tag": "ABX9TyNRcUVF3ZzTw+oK4ortpcH+",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/ferluht/Mubert-Text-to-Music/blob/main/Mubert_Text_to_Music.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# **Mubert Text to Music ✍ ➡ 🎹🎵🔊**\n",
"\n",
"A simple notebook demonstrating prompt-based music generation via [Mubert](https://mubert.com) [API](https://mubert2.docs.apiary.io/)"
],
"metadata": {
"id": "gHJPhnu7Lg2v"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "GPdDFKWVVnif"
},
"outputs": [],
"source": [
"#@title **Setup Environment**\n",
"\n",
"import subprocess, time\n",
"print(\"Setting up environment...\")\n",
"start_time = time.time()\n",
"all_process = [\n",
" ['pip', 'install', 'torch==1.12.1+cu113', 'torchvision==0.13.1+cu113', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'],\n",
" ['pip', 'install', '-U', 'sentence-transformers'],\n",
" ['pip', 'install', 'httpx'],\n",
"]\n",
"for process in all_process:\n",
" running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
"\n",
"end_time = time.time()\n",
"print(f\"Environment set up in {end_time-start_time:.0f} seconds\")"
]
},
{
"cell_type": "code",
"source": [
"#@title **Define Mubert methods and pre-compute things**\n",
"\n",
"import numpy as np\n",
"from sentence_transformers import SentenceTransformer\n",
"minilm = SentenceTransformer('all-MiniLM-L6-v2')\n",
"\n",
"mubert_tags_string = 'tribal,action,kids,neo-classic,run 130,pumped,jazz / funk,ethnic,dubtechno,reggae,acid jazz,liquidfunk,funk,witch house,tech house,underground,artists,mystical,disco,sensorium,r&b,agender,psychedelic trance / psytrance,peaceful,run 140,piano,run 160,setting,meditation,christmas,ambient,horror,cinematic,electro house,idm,bass,minimal,underscore,drums,glitchy,beautiful,technology,tribal house,country pop,jazz & funk,documentary,space,classical,valentines,chillstep,experimental,trap,new jack swing,drama,post-rock,tense,corporate,neutral,happy,analog,funky,spiritual,sberzvuk special,chill hop,dramatic,catchy,holidays,fitness 90,optimistic,orchestra,acid techno,energizing,romantic,minimal house,breaks,hyper pop,warm up,dreamy,dark,urban,microfunk,dub,nu disco,vogue,keys,hardcore,aggressive,indie,electro funk,beauty,relaxing,trance,pop,hiphop,soft,acoustic,chillrave / ethno-house,deep techno,angry,dance,fun,dubstep,tropical,latin pop,heroic,world music,inspirational,uplifting,atmosphere,art,epic,advertising,chillout,scary,spooky,slow ballad,saxophone,summer,erotic,jazzy,energy 100,kara mar,xmas,atmospheric,indie pop,hip-hop,yoga,reggaeton,lounge,travel,running,folk,chillrave & ethno-house,detective,darkambient,chill,fantasy,minimal techno,special,night,tropical house,downtempo,lullaby,meditative,upbeat,glitch hop,fitness,neurofunk,sexual,indie rock,future pop,jazz,cyberpunk,melancholic,happy hardcore,family / kids,synths,electric guitar,comedy,psychedelic trance & psytrance,edm,psychedelic rock,calm,zen,bells,podcast,melodic house,ethnic percussion,nature,heavy,bassline,indie dance,techno,drumnbass,synth pop,vaporwave,sad,8-bit,chillgressive,deep,orchestral,futuristic,hardtechno,nostalgic,big room,sci-fi,tutorial,joyful,pads,minimal 170,drill,ethnic 108,amusing,sleepy ambient,psychill,italo disco,lofi,house,acoustic guitar,bassline house,rock,k-pop,synthwave,deep house,electronica,gabber,nightlife,sport & fitness,road trip,celebration,electro,disco house,electronic'\n",
"mubert_tags = np.array(mubert_tags_string.split(','))\n",
"mubert_tags_embeddings = minilm.encode(mubert_tags)\n",
"\n",
"from IPython.display import Audio, display\n",
"import httpx\n",
"import json\n",
"\n",
"def get_track_by_tags(tags, pat, duration, maxit=20, autoplay=False, loop=False):\n",
" if loop:\n",
" mode = \"loop\"\n",
" else:\n",
" mode = \"track\"\n",
" r = httpx.post('https://api-b2b.mubert.com/v2/RecordTrackTTM', \n",
" json={\n",
" \"method\":\"RecordTrackTTM\",\n",
" \"params\": {\n",
" \"pat\": pat, \n",
" \"duration\": duration,\n",
" \"tags\": tags,\n",
" \"mode\": mode\n",
" }\n",
" })\n",
"\n",
" rdata = json.loads(r.text)\n",
" assert rdata['status'] == 1, rdata['error']['text']\n",
" trackurl = rdata['data']['tasks'][0]['download_link']\n",
"\n",
" print('Generating track ', end='')\n",
" for i in range(maxit):\n",
" r = httpx.get(trackurl)\n",
" if r.status_code == 200:\n",
" display(Audio(trackurl, autoplay=autoplay))\n",
" break\n",
" time.sleep(1)\n",
" print('.', end='')\n",
"\n",
"def find_similar(em, embeddings, method='cosine'):\n",
" scores = []\n",
" for ref in embeddings:\n",
" if method == 'cosine': \n",
" scores.append(1 - np.dot(ref, em)/(np.linalg.norm(ref)*np.linalg.norm(em)))\n",
" if method == 'norm': \n",
" scores.append(np.linalg.norm(ref - em))\n",
" return np.array(scores), np.argsort(scores)\n",
"\n",
"def get_tags_for_prompts(prompts, top_n=3, debug=False):\n",
" prompts_embeddings = minilm.encode(prompts)\n",
" ret = []\n",
" for i, pe in enumerate(prompts_embeddings):\n",
" scores, idxs = find_similar(pe, mubert_tags_embeddings)\n",
" top_tags = mubert_tags[idxs[:top_n]]\n",
" top_prob = 1 - scores[idxs[:top_n]]\n",
" if debug:\n",
" print(f\"Prompt: {prompts[i]}\\nTags: {', '.join(top_tags)}\\nScores: {top_prob}\\n\\n\\n\")\n",
" ret.append((prompts[i], list(top_tags)))\n",
" return ret"
],
"metadata": {
"cellView": "form",
"id": "yW-3aTNYvKM_"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@markdown **Get personal access token in Mubert and define API methods**\n",
"email = \"your e-mail here\" #@param {type:\"string\"}\n",
"\n",
"r = httpx.post('https://api-b2b.mubert.com/v2/GetServiceAccess', \n",
" json={\n",
" \"method\":\"GetServiceAccess\",\n",
" \"params\": {\n",
" \"email\": email,\n",
" \"license\":\"ttmmubertlicense#f0acYBenRcfeFpNT4wpYGaTQIyDI4mJGv5MfIhBFz97NXDwDNFHmMRsBSzmGsJwbTpP1A6i07AXcIeAHo5\",\n",
" \"token\":\"4951f6428e83172a4f39de05d5b3ab10d58560b8\",\n",
" \"mode\": \"loop\"\n",
" }\n",
" })\n",
"\n",
"rdata = json.loads(r.text)\n",
"assert rdata['status'] == 1, \"probably incorrect e-mail\"\n",
"pat = rdata['data']['pat']\n",
"print(f'Got token: {pat}')"
],
"metadata": {
"cellView": "form",
"id": "a4ACdvWLRJ5U"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title **Generate some music 🎵**\n",
"\n",
"prompt = 'vladimir lenin smoking weed with bob marley' #@param {type:\"string\"}\n",
"duration = 30 #@param {type:\"number\"}\n",
"loop = False #@param {type:\"boolean\"}\n",
"\n",
"def generate_track_by_prompt(prompt, duration, loop=False):\n",
" _, tags = get_tags_for_prompts([prompt,])[0]\n",
" try:\n",
" get_track_by_tags(tags, pat, duration, autoplay=True, loop=loop)\n",
" except Exception as e:\n",
" print(str(e))\n",
" print('\\n')\n",
"\n",
"generate_track_by_prompt(prompt, duration, loop)\n"
],
"metadata": {
"cellView": "form",
"id": "hTf7sZcfbI0K"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### **Batch generation 🎶**"
],
"metadata": {
"id": "wSKTfub-bitp"
}
},
{
"cell_type": "code",
"source": [
"duration = 60\n",
"\n",
"prompts = [\n",
" 'kind beaver guards life tree, stan lee, epic',\n",
" 'astronaut riding a horse',\n",
" 'winnie the pooh cooking methamphetamine',\n",
" 'vladimir lenin smoking weed with bob marley',\n",
" 'soviet retrofuturism',\n",
" 'two wasted friends high on weed are trying to navigate their way to their hostel in a big city, night, trippy',\n",
" 'an elephant levitating on a gas balloon',\n",
" 'calm music',\n",
" 'a refrigerator floating in a pond'\n",
"]\n",
"\n",
"tags = get_tags_for_prompts(prompts)\n",
"\n",
"for i, tag in enumerate(tags):\n",
" print(f'Prompt: {tag[0]}\\nTags: {tag[1]}')\n",
" try:\n",
" get_track_by_tags(tag[1], pat, duration, autoplay=False)\n",
" except Exception as e:\n",
" print(str(e))\n",
" print('\\n')"
],
"metadata": {
"id": "BzrhcwIHXlg0"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "JieLk6kjZFai"
},
"execution_count": null,
"outputs": []
}
]
}