[Exercism][Rust] Parallel Letter Frequency

題目網址(需登入 Exercism)

簡單來說,就是要用 concurrency 的方式,將「計算文章中每個字母(數字和標點不計)出現的次數」,分配給 worker_count 個 thread 去做,最後再整合在一起。

輸入的資料,第一個是 &str 的 array slice,第二個是 worker_count

問題 1:如何將工作/資料平均分配為 worker_count 份?

第一個直覺的想法,就是從 array 的第 0 個開始,依序分配給不同的 worker。

假設 i 是 worker 的 index,那麼第 i 個 worker 所需要處理的資料可以利用 iterator 挑出來:

input.iter()
    .skip(i)
    .step_by(work_count)
    //...

另外一個想法,是讓資料盡可能連續:

根據我的實驗,後者稍微比前者快一點點。但比較麻煩的是需要計算要怎麼分割。

首先,每個 worker 至少都會分配到 input.len() / worker_count 行。這樣會剩下 r = input.len() % worker_count 行。接下來只要前 r 個 worker 一人多認領一行就可以了。假設 worker 的 index 為 i,只要 m > i,就再多加一行。因此每個 worker 分配到的行數為:

let length = input.len() / worker_count +
    if input.len() % worker_count > i {
        1
    } else {
        0
    };

至於要從哪一行開始拿,就只要把前面 worker 的 length 累加起來即可。

所以最後的程式碼如下:

    let quotient = input.len() / worker_count;
    let reminder = input.len() % worker_count;
    let mut start_index = 0;
    for i in 0..worker_count {
        let length = quotient + if reminder > i { 1 } else { 0 };
        if length == 0 {
            break;
        }
        let v: Vec<_> = (&input[start_index..(start_index + length)])
            .into_iter()
        //...
        start_index += length;
    }

問題 2:如何處理傳入 thread 中的資料?

由於 spawned thread 可能會活得比 main thread 還長,因此若是 main thread 要把資料傳給 spawned thread,不能傳一般的 reference 過去(要是在 spawned thread 結束之前,main thread 就結束了,那reference 不就找不到資料了嗎?)。因此最直覺的作法,是把要傳入的資料轉成 owned data,然後利用 move 直接把所有權交給 spawned thread:

    for i in 0..worker_count {
        //....
        let v: Vec<_> = (&input[start_index..(start_index + length)])
            .into_iter()
            .map(|&s| String::from(s))
            .collect();
 
        handles.push(thread::spawn(move || {
            v.into_iter()           // <-- v 被 move 進來
                .for_each(|s| {     // s 的型別為 String
            //....

建立 v 是省不掉的。但把所有的 &str 轉成 String 實在很花時間。如果我們為 input: &[&str] 中的 &str 加上 static 的 lifetime 限制,就可以保證 &str 的存活時間比 spawned thread 長,這樣就可以直接把 &str 傳進 spawned thread 中了:

pub fn frequency(input: &[&'static str], worker_count: usize) -> HashMap<char, usize> {
    //...
    for i in 0..worker_count {
        //....
        let v: Vec<_> = (&input[start_index..(start_index + length)])
            .into_iter()
            // 到這邊 item 的 type 是 &&str,因此需要 deref 一下...
            .map(|&s| s)
            .collect();
 
        handles.push(thread::spawn(move || {
            v.into_iter().for_each(|s| {    // 此處 s 的型別就是 &str,而非 String
                //...

問題 3:如何整合所有的資料?

每個 spawned thread 都會把自己負責的部份存進自己的 hash map 中,最後還是要整合成一個的。我的作法是讓每個 spawned thread 都回傳自己的那一份 hash map,然後在 main thread 中整合。

    let mut handles = vec![];
    for i in 0..worker_count {
        //...
        handles.push(thread::spawn(move || {
            let mut result = HashMap::new();
            //...
            result
        }));
        //...
    }

    handles
        .into_iter()
        .map(|h| h.join().unwrap())
        .fold(HashMap::new(), |mut acc, single| {
            single.into_iter().for_each(|(ch, count)| {
                let entry = acc.entry(ch).or_insert(0);
                *entry += count;
            });
            acc
        })

另一個想法,是利用 ArcMutex,把 main thread 中的 hash map 傳進 spawned thread 中,由每個 spawned thread 自行把結果塞進最後的 hash map 中。這個方法可行,但實驗結果沒有快多少,而且資料建立起來又比較麻煩,所以最後就沒這樣做了。

    let result = Arc::new(Mutex::new(HashMap::new()));
    for i in 0..worker_count {
        //...
        let r = Arc::clone(&result);
        handles.push(thread::spawn(move || {
            //...
            result.into_iter().for_each(|(ch, count)| {
                let mut m = r.lock().unwrap();
                let entry = m.entry(ch).or_insert(0);
                *entry += count;
            })
            //...
        }));
        //..
    }
    handles
        .into_iter()
        .for_each(|h| h.join().unwrap());

    // 把資料從 Arc<Mutex<>> 中拿出來的方法。
    // 參考資料:https://stackoverflow.com/questions/29177449/how-to-take-ownership-of-t-from-arcmutext
    Arc::try_unwrap(result).unwrap().into_inner().unwrap()

完整程式

use std::collections::HashMap;
use std::thread;

pub fn frequency(input: &[&'static str], worker_count: usize) -> HashMap<char, usize> {
    let mut handles = vec![];
    let quotient = input.len() / worker_count;
    let reminder = input.len() % worker_count;
    let mut start_index = 0;
    for i in 0..worker_count {
        let length = quotient + if reminder > i { 1 } else { 0 };
        if length == 0 {
            break;
        }
        let v: Vec<_> = (&input[start_index..(start_index + length)])
            .into_iter()
            .map(|&s| s)
            .collect();
        handles.push(thread::spawn(move || {
            let mut result = HashMap::new();
            v.into_iter().for_each(|s| {
                s.chars()
                    .filter(|c| c.is_alphabetic())
                    .map(|c| c.to_ascii_lowercase())
                    .for_each(|c| {
                        *(result.entry(c).or_insert(0)) += 1;
                    })
            });
            result
        }));
        start_index += length;
    }

    handles
        .into_iter()
        .map(|h| h.join().unwrap())
        // merges results from different workers into one
        .fold(HashMap::new(), |mut acc, single| {
            single.into_iter().for_each(|(ch, count)| {
                let entry = acc.entry(ch).or_insert(0);
                *entry += count;
            });
            acc
        })
}

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *