void CvDTree::split_node_data()

in apps/traincascade/old_ml_tree.cpp [3034:3315]


// Scatters the relocated indices of one ordered variable to the two children.
// For k in [from, to): pos = order[k] is a parent-local sample position; the
// value new_idx[pos] is appended at rdst when dir[pos] is non-zero and at ldst
// otherwise. Both cursors are taken by reference and left pointing one past
// the last element written, so the caller can measure how much each side got.
// Returns the first index that was not processed (i.e. the final value of k).
template<typename T> static int
icvScatterRelocated( const int* order, const char* dir, const int* new_idx,
                     int from, int to, T*& ldst, T*& rdst )
{
    int k = from;
    for( ; k < to; k++ )
    {
        const int pos = order[k];
        T*& dst = dir[pos] ? rdst : ldst;
        *dst = (T)new_idx[pos];
        ++dst;
    }
    return k;
}

// Copies vals[0..count) to the two children in parent order: element k is
// appended at rdst when dir[k] is non-zero and at ldst otherwise.
template<typename T> static void
icvScatterInOrder( const int* vals, const char* dir, int count, T* ldst, T* rdst )
{
    for( int k = 0; k < count; k++ )
    {
        const T v = (T)vals[k];
        if( dir[k] )
            *rdst++ = v;
        else
            *ldst++ = v;
    }
}

// Creates the left and right children of `node` and distributes the node's
// per-variable data between them according to the per-sample direction flags
// in data->direction (zero => left child, non-zero => right child). Finally
// the parent's node data is released through data->free_node_data().
//
// Ordered and categorical input variables are only copied to the children when
// the children may be split further: depth + 1 is below params.max_depth and
// at least one child holds more than params.min_sample_count samples. The
// remaining work variables (index >= data->var_count) and the sample indices
// are always copied.
void CvDTree::split_node_data( CvDTreeNode* node )
{
    const int n = node->sample_count;
    const int scount = data->sample_count;
    const char* dir = (const char*)data->direction->data.ptr;
    int* new_idx = data->split_buf->data.i;
    const int new_buf_idx = data->get_child_buf_idx( node );
    const int work_var_count = data->get_work_var_count();
    CvMat* buf = data->buf;
    const size_t length_buf_row = data->get_length_subbuf();

    // Scratch memory, n*(3*sizeof(int) + sizeof(float)) bytes. The first n ints
    // (temp_buf) hold a private snapshot of the array currently being split;
    // the space behind it is handed to the data getters as their work buffer.
    cv::AutoBuffer<uchar> inn_buf(n*(3*sizeof(int) + sizeof(float)));
    int* temp_buf = (int*)(uchar*)inn_buf;

    // NOTE(review): presumably assigns a direction to every sample of the node,
    // including the ones the primary split could not decide - confirm.
    complete_node_dir(node);

    // Build the relocation table: new_idx[i] is the running count of samples
    // already sent to the same side as sample i.
    int nl = 0, nr = 0;
    for( int i = 0; i < n; i++ )
    {
        const int d = dir[i];
        // branch-free form of "d ? nr : nl"; relies on d being 0 or 1
        new_idx[i] = (nl & (d-1)) | (nr & -d);
        nr += d;
        nl += d^1;
    }

    CvDTreeNode* left = data->new_node( node, nl, new_buf_idx, node->offset );
    node->left = left;
    CvDTreeNode* right = data->new_node( node, nr, new_buf_idx, node->offset + nl );
    node->right = right;

    const bool split_input_data = node->depth + 1 < data->params.max_depth &&
        (left->sample_count > data->params.min_sample_count ||
         right->sample_count > data->params.min_sample_count);

    // Ordered variables: walk the parent's sorted order so that both halves
    // stay sorted after the split.
    for( int vi = 0; vi < data->var_count; vi++ )
    {
        const bool is_ordered = data->get_var_type(vi) < 0;
        if( !is_ordered || !split_input_data )
            continue;

        const int n1 = node->get_num_valid(vi);

        float* val_scratch = (float*)(uchar*)(temp_buf + n);
        int* sorted_scratch = (int*)(val_scratch + n);
        int* sample_scratch = sorted_scratch + n;
        const float* sorted_vals = 0;
        const int* sorted_order = 0;
        data->get_ord_var_data(node, vi, val_scratch, sorted_scratch,
                               &sorted_vals, &sorted_order, sample_scratch);

        // snapshot the order before the children's buffers are written
        for( int i = 0; i < n; i++ )
            temp_buf[i] = sorted_order[i];

        if( data->is_buf_16u )
        {
            unsigned short* const lbeg = (unsigned short*)(buf->data.s + left->buf_idx*length_buf_row +
                vi*scount + left->offset);
            // here the right block is taken to start directly after the nl
            // left entries instead of being addressed through `right`
            unsigned short* const rbeg = (unsigned short*)(lbeg + nl);
            unsigned short* lcur = lbeg;
            unsigned short* rcur = rbeg;

            // first n1 entries of the order: counted as valid
            const int next = icvScatterRelocated( temp_buf, dir, new_idx, 0, n1, lcur, rcur );

            left->set_num_valid(vi, (int)(lcur - lbeg));
            right->set_num_valid(vi, (int)(rcur - rbeg));

            // remaining entries (missing values) follow the valid ones
            icvScatterRelocated( temp_buf, dir, new_idx, next, n, lcur, rcur );
        }
        else
        {
            int* const lbeg = buf->data.i + left->buf_idx*length_buf_row +
                vi*scount + left->offset;
            int* const rbeg = buf->data.i + right->buf_idx*length_buf_row +
                vi*scount + right->offset;
            int* lcur = lbeg;
            int* rcur = rbeg;

            // first n1 entries of the order: counted as valid
            const int next = icvScatterRelocated( temp_buf, dir, new_idx, 0, n1, lcur, rcur );

            left->set_num_valid(vi, (int)(lcur - lbeg));
            right->set_num_valid(vi, (int)(rcur - rbeg));

            // remaining entries (missing values) follow the valid ones
            icvScatterRelocated( temp_buf, dir, new_idx, next, n, lcur, rcur );
        }
    }

    // Categorical variables, responses and cv_labels: the values are stored in
    // sample order, so they are simply dealt out to the children one by one.
    for( int vi = 0; vi < work_var_count; vi++ )
    {
        const int ci = data->get_var_type(vi);
        const int n1 = node->get_num_valid(vi);

        if( ci < 0 || (vi < data->var_count && !split_input_data) )
            continue;

        const int* src_lbls = data->get_cat_var_data(node, vi, temp_buf + n);

        for( int i = 0; i < n; i++ )
            temp_buf[i] = src_lbls[i];

        // number of right-going values that are not excluded from the valid
        // count (65535 in 16-bit buffers, negative values in int buffers)
        int nr1 = 0;

        if( data->is_buf_16u )
        {
            unsigned short* ldst = (unsigned short *)(buf->data.s + left->buf_idx*length_buf_row +
                vi*scount + left->offset);
            unsigned short* rdst = (unsigned short *)(buf->data.s + right->buf_idx*length_buf_row +
                vi*scount + right->offset);

            for( int i = 0; i < n; i++ )
            {
                const int d = dir[i];
                const int label = temp_buf[i];
                if( d )
                {
                    *rdst++ = (unsigned short)label;
                    nr1 += (label != 65535 )&d;
                }
                else
                    *ldst++ = (unsigned short)label;
            }
        }
        else
        {
            int* ldst = buf->data.i + left->buf_idx*length_buf_row +
                vi*scount + left->offset;
            int* rdst = buf->data.i + right->buf_idx*length_buf_row +
                vi*scount + right->offset;

            for( int i = 0; i < n; i++ )
            {
                const int d = dir[i];
                const int label = temp_buf[i];
                if( d )
                {
                    *rdst++ = label;
                    nr1 += (label >= 0)&d;
                }
                else
                    *ldst++ = label;
            }
        }

        // valid counts are only maintained for the input variables
        if( vi < data->var_count )
        {
            left->set_num_valid(vi, n1 - nr1);
            right->set_num_valid(vi, nr1);
        }
    }

    // Sample indices: stored in the slot right after the last work variable.
    const int* sample_idx_src = data->get_sample_indices(node, temp_buf + n);

    for( int i = 0; i < n; i++ )
        temp_buf[i] = sample_idx_src[i];

    const int pos = data->get_work_var_count();
    if( data->is_buf_16u )
    {
        unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*length_buf_row +
            pos*scount + left->offset);
        unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*length_buf_row +
            pos*scount + right->offset);
        icvScatterInOrder( temp_buf, dir, n, ldst, rdst );
    }
    else
    {
        int* ldst = buf->data.i + left->buf_idx*length_buf_row +
            pos*scount + left->offset;
        int* rdst = buf->data.i + right->buf_idx*length_buf_row +
            pos*scount + right->offset;
        icvScatterInOrder( temp_buf, dir, n, ldst, rdst );
    }

    // the parent's node data is not needed once both children are populated
    data->free_node_data(node);
}