Generatiivinen tiedustelu

Hienosäädä Whisper-malleja Amazon SageMakerissa LoRA:lla | Amazon Web Services

Treffi:

Whisper on automaattinen puheentunnistus (ASR) -malli, joka on koulutettu käyttämällä 680,000 100 tuntia valvottua tietoa verkosta ja joka kattaa useita kieliä ja tehtäviä. Yksi sen rajoituksista on heikko suorituskyky vähän resursseja vaativilla kielillä, kuten marathin kielellä ja dravidilaisilla kielillä, joita voidaan korjata hienosäädöllä. Whisper-mallin hienosäätö on kuitenkin muodostunut huomattavaksi haasteeksi sekä laskentaresurssien että tallennusvaatimusten kannalta. Viidestä kymmeneen Whisper-mallien täydellistä hienosäätöä tarvitaan noin 100 tuntia A40 GPU:ta (4 Gt SXM7) (vaihtelee mallin koon ja malliparametrien mukaan), ja jokainen hienosäädetty tarkistuspiste vaatii noin XNUMX Gt tallennustilaa. Tämä korkeiden laskenta- ja tallennusvaatimusten yhdistelmä voi aiheuttaa merkittäviä esteitä erityisesti rajallisissa resursseissa, mikä tekee mielekkäiden tulosten saavuttamisen usein poikkeuksellisen vaikeaksi.

Low-Rank Adaption, joka tunnetaan myös nimellä LoRA, käyttää ainutlaatuista lähestymistapaa mallin hienosäätöön. Se ylläpitää esiopetetut mallipainot staattisessa tilassa ja lisää koulutettavia järjestyshajotusmatriiseja Transformer-rakenteen jokaiseen kerrokseen. Tämä menetelmä voi vähentää jatkotehtäviin tarvittavien koulutettavien parametrien määrää 10,000 3 kertaa ja vähentää GPU-muistin tarvetta XNUMX kertaa. Mallin laadun osalta LoRA:n on osoitettu vastaavan tai jopa ylittävän perinteisten hienosäätömenetelmien suorituskyvyn huolimatta siitä, että se toimii harvemmilla koulutettavilla parametreilla (katso tulokset alkuperäisestä LoRA paperia). Se tarjoaa myös paremman harjoittelun suorituskyvyn. toisin kuin sovitin menetelmiä, LoRA ei lisää latenssia päättelyn aikana, mikä säilyttää mallin tehokkuuden käyttöönottovaiheen aikana. Whisperin hienosäätö LoRA:lla on osoittanut lupaavia tuloksia. Otetaan esimerkiksi Whisper-Large-v2: 3-ajanjakson käyttäminen 12 tunnin yhteisellä äänidatajoukolla 8 Gt:n muistissa GPU:ssa kestää 6–8 tuntia, joka on 5 kertaa nopeampi kuin täydellinen hienosäätö vertailukelpoisella suorituskyvyllä.

Amazon Sage Maker on ihanteellinen alusta Whisperin LoRA-hienosäädön toteuttamiseen. Amazon SageMakerin avulla voit rakentaa, kouluttaa ja ottaa käyttöön koneoppimismalleja kaikissa käyttötapauksissa täysin hallitun infrastruktuurin, työkalujen ja työnkulkujen avulla. Mallin koulutuksen lisäetuja voivat olla alhaisemmat koulutuskustannukset Managed Spot Trainingin avulla, hajautetut koulutuskirjastot mallien jakamiseksi ja koulutustietojoukot AWS GPU -esiintymien välillä ja lisää. Koulutetut SageMaker-mallit voidaan helposti ottaa käyttöön johtopäätösten tekemiseksi suoraan SageMakerissa. Tässä viestissä esittelemme vaiheittaisen oppaan LoRA-hienosäädön toteuttamiseksi SageMakerissa. Tähän toteutukseen liittyvä lähdekoodi löytyy osoitteesta GitHub.

Valmistele tietojoukko hienosäätöä varten

Käytämme hienosäätötehtävään vähän resursseja käyttävää marathi-kieltä. Käyttämällä Hugging Face -tietojoukot kirjastosta, voit ladata ja jakaa Common Voice -tietojoukon koulutus- ja testaustietojoukoiksi. Katso seuraava koodi:

from datasets import load_dataset, DatasetDict language = "Marathi"
language_abbr = "mr"
task = "transcribe"
dataset_name = "mozilla-foundation/common_voice_11_0" common_voice = DatasetDict()
common_voice["train"] = load_dataset(dataset_name, language_abbr, split="train+validation", use_auth_token=True)
common_voice["test"] = load_dataset(dataset_name, language_abbr, split="test", use_auth_token=True)

Whisper-puheentunnistusmalli edellyttää, että äänitulot ovat 16 kHz mono 16-bittinen etumerkillinen kokonaisluku WAV-tiedostot. Koska Common Voice -tietojoukon näytteenottotaajuus on 48 XNUMX, sinun on ensin näytteistettävä äänitiedostot. Sitten sinun on käytettävä Whisperin ominaisuuspoimijaa äänessä log-mel-spektrogrammin ominaisuuksien poimimiseksi ja käytettävä Whisperin tokenisaattoria kehystettyihin ominaisuuksiin, jotta transkription jokainen lause muunnetaan tunnukseksi. Katso seuraava koodi:

from transformers import WhisperFeatureExtractor
from transformers import WhisperTokenizer feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name_or_path)
tokenizer = WhisperTokenizer.from_pretrained(model_name_or_path, language=language, task=task) def prepare_dataset(batch):
# load and resample audio data from 48 to 16kHz
audio = batch["audio"] # compute log-Mel input features from input audio array
batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0] # encode target text to label ids
batch["labels"] = tokenizer(batch["sentence"]).input_ids
return batch #apply the data preparation function to all of our fine-tuning dataset samples using dataset's .map method.
common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=2)
common_voice.save_to_disk("marathi-common-voice-processed")
!aws s3 cp --recursive "marathi-common-voice-processed" s3://<Your-S3-Bucket>

Kun olet käsitellyt kaikki harjoitusnäytteet, lataa käsitellyt tiedot Amazon S3:een, jotta voit käyttää käsiteltyä harjoitustietoa hienosäätövaiheessa FastFile liittääksesi S3-tiedoston suoraan paikalliselle levylle kopioimisen sijaan:

from sagemaker.inputs import TrainingInput
training_input_path=s3uri
training = TrainingInput(
s3_data_type='S3Prefix', # Available Options: S3Prefix | ManifestFile | AugmentedManifestFile
s3_data=training_input_path,
distribution='FullyReplicated', # Available Options: FullyReplicated | ShardedByS3Key
input_mode='FastFile'
)

Harjoittele mallia

Esittelyä varten käytämme esikoulutettuna mallina whisper-large-v2 (whisper v3 on nyt saatavilla), joka voidaan tuoda Hugging Face transformers -kirjaston kautta. Voit käyttää 8-bittinen kvantisointi parantaa harjoittelun tehokkuutta entisestään. 8-bittinen kvantisointi tarjoaa muistin optimoinnin pyöristämällä liukulukusta 8-bittisiin kokonaislukuihin. Se on yleisesti käytetty mallipakkaustekniikka, jonka avulla voidaan säästää pienennetystä muistista tinkimättä tarkkuudesta päättelyn aikana liikaa.

Esiopetetun mallin lataamiseksi 8-bittisessä kvantisoidussa muodossa lisäämme yksinkertaisesti argumentin load_in_8bit=True mallin luomisen yhteydessä, kuten seuraavassa koodissa näkyy. Tämä lataa mallin painot, jotka on kvantisoitu 8 bittiin, mikä vähentää muistin tilaa.

from transformers import WhisperForConditionalGeneration model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, load_in_8bit=True, device_map="auto")

Käytämme Hugging Facen LoRA-toteutusta peft paketti. Mallin hienosäätäminen LoRA:lla on neljä vaihetta:

  1. Luo perusmalli (kuten teimme viimeisessä vaiheessa).
  2. Luo kokoonpano (LoraConfig), jossa määritellään LoRA-spesifiset parametrit.
  3. Kääri pohjamalli get_peft_model() saada koulutettavaa PeftModel.
  4. Kouluta PeftModel perusmallina.

Katso seuraava koodi:

from peft import LoraConfig, get_peft_model config = LoraConfig(r=32, lora_alpha=64, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none")
model = get_peft_model(model, config) training_args = Seq2SeqTrainingArguments(
output_dir=args.model_dir,
per_device_train_batch_size=int(args.train_batch_size),
gradient_accumulation_steps=1,
learning_rate=float(args.learning_rate),
warmup_steps=args.warmup_steps,
num_train_epochs=args.num_train_epochs,
evaluation_strategy="epoch",
fp16=True,
per_device_eval_batch_size=args.eval_batch_size,
generation_max_length=128,
logging_steps=25,
remove_unused_columns=False,
label_names=["labels"],
)
trainer = Seq2SeqTrainer(
args=training_args,
model=model,
train_dataset=train_dataset["train"],
eval_dataset=train_dataset.get("test", train_dataset["test"]),
data_collator=data_collator,
tokenizer=processor.feature_extractor,
)

Suorita a SageMaker koulutus työ, tuomme oman Docker-kontin. Voit ladata Docker-kuvan osoitteesta GitHub, jossa ffmpeg4 ja git-lfs on pakattu yhdessä muiden Python-vaatimusten kanssa. Lisätietoja oman Docker-säilön mukauttamisesta toimimaan SageMakerin kanssa on kohdassa Oman harjoituskontin mukauttaminen. Sitten voit käyttää Hugging Face Estimatoria ja aloittaa SageMaker-harjoitustyön:

OUTPUT_PATH= f's3://{BUCKET}/{PREFIX}/{TRAINING_JOB_NAME}/output/' huggingface_estimator = HuggingFace(entry_point='train.sh',
source_dir='./src',
output_path= OUTPUT_PATH,
instance_type=instance_type,
instance_count=1,
# transformers_version='4.17.0',
# pytorch_version='1.10.2',
py_version='py310',
image_uri=<ECR-PATH>,
role=ROLE,
metric_definitions = metric_definitions,
volume_size=200,
distribution=distribution,
keep_alive_period_in_seconds=1800,
environment=environment,
) huggingface_estimator.fit(job_name=TRAINING_JOB_NAME, wait=False)

LoRA:n käyttöönotto mahdollisti Whisperin suuren hienosäätötehtävän suorittamisen yhdellä GPU-esiintymällä (esimerkiksi ml.g5.2xlarge). Vertailun vuoksi, Whisper large täysi hienosäätötehtävä vaatii useita GPU:ita (esimerkiksi ml.p4d.24xlarge) ja paljon pidemmän harjoitusajan. Tarkemmin sanottuna kokeilumme osoitti, että täydellinen hienosäätötehtävä vaatii 24 kertaa enemmän GPU-tuntia LoRA-lähestymistapaan verrattuna.

Arvioi mallin suorituskykyä

Hienosäädetyn Whisper-mallin suorituskyvyn arvioimiseksi laskemme sanan virhesuhteen (WER) pidennetyssä testijoukossa. WER mittaa eron ennustetun transkriptin ja perustotuustranskription välillä. Pienempi WER tarkoittaa parempaa suorituskykyä. Voit suorittaa seuraavan skriptin esikoulutettua mallia ja hienosäädettyä mallia vastaan ​​ja vertailla niiden WER-eroa:

metric = evaluate.load("wer") eval_dataloader = DataLoader(common_voice["test"], batch_size=8, collate_fn=data_collator) model.eval()
for step, batch in enumerate(tqdm(eval_dataloader)):
with torch.cuda.amp.autocast():
with torch.no_grad():
generated_tokens = (
model.generate(
input_features=batch["input_features"].to("cuda"),
decoder_input_ids=batch["labels"][:, :4].to("cuda"),
max_new_tokens=255,
)
.cpu()
.numpy()
)
labels = batch["labels"].cpu().numpy()
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
metric.add_batch(
predictions=decoded_preds,
references=decoded_labels,
)
del generated_tokens, labels, batch
gc.collect()
wer = 100 * metric.compute()
print(f"{wer=}")

Yhteenveto

Tässä viestissä esitimme hienosäädettävän Whisperin, huippuluokan puheentunnistusmallin. Erityisesti käytimme Hugging Facen PEFT LoRAa ja mahdollistimme 8-bittisen kvantisoinnin tehokkaaseen harjoitteluun. Osoitimme myös, kuinka harjoitustyötä suoritetaan SageMakerissa.

Vaikka tämä on tärkeä ensimmäinen askel, on useita tapoja, joilla voit kehittää tätä työtä parantaaksesi entisestään kuiskausmallia. Harkitse jatkossa SageMaker-hajautetun koulutuksen käyttöä koulutuksen skaalaamiseksi paljon suuremmalla tietojoukolla. Tämä antaa mallille mahdollisuuden harjoitella monipuolisempaa ja kattavampaa dataa, mikä parantaa tarkkuutta. Voit myös optimoida viiveen Whisper-mallia käyttäessäsi reaaliaikaisen puheentunnistuksen mahdollistamiseksi. Lisäksi voit laajentaa työtä käsitelläksesi pidempiä äänitranskriptioita, mikä edellyttää muutoksia malliarkkitehtuuriin ja koulutussuunnitelmiin.

Kuittaus

Kirjoittajat kiittävät Paras Mehraa, John Solia ja Evandro Francoa oivaltavasta palautteesta ja julkaisun arvostelusta.


Tietoja Tekijät

Jun Shi on vanhempi ratkaisuarkkitehti Amazon Web Servicesissä (AWS). Hänen tämänhetkisen painopistealueensa ovat AI/ML-infrastruktuuri ja sovellukset. Hänellä on yli vuosikymmenen kokemus FinTech-alalta ohjelmistoinsinöörinä.

DR. Changsha Ma on AI/ML-asiantuntija AWS:ssä. Hän on tekniikan tohtori, jolla on tietojenkäsittelytieteen tohtori, koulutuspsykologian maisterin tutkinto ja vuosien kokemus datatieteestä ja riippumattomasta konsultoinnista AI/ML:ssä. Hän on intohimoinen kone- ja ihmisälyn metodologisten lähestymistapojen tutkimiseen. Työn ulkopuolella hän rakastaa patikointia, ruoanlaittoa, metsästää ruokaa ja viettää aikaa ystävien ja perheen kanssa.

spot_img

Uusin älykkyys

spot_img