首页 文章

BST - 间隔删除/多个节点删除

提问于
浏览
-1

假设我有一个二叉搜索树,其中我应该按照标准输入给我的顺序插入N个唯一编号的键,然后我将删除所有带有键的节点I = [min,max]以及与这些节点相邻的所有连接 . 这给了我很多小树,我将以特定的方式合并在一起 . 更准确地描述问题:

给定包含不同密钥的BST和间隔I,间隔删除分两个阶段进行 . 在第一阶段,它将删除其密钥位于I中的所有节点以及与删除的节点相邻的所有边缘 . 让结果图包含k个连通分量T1,...,Tk . 每个组件都是BST,其中根是原始BST中该组件的所有节点中具有最小深度的节点 . 我们假设树的序列Ti被排序,使得对于每个i <j,Ti中的所有键都小于Tj中的键 . 在第二阶段期间,树木Ti合并在一起以形成一个BST . 我们用Merge(T1,...,Tk)表示这个操作 . 其输出定期反复定义如下:

编辑:我也应该删除连接节点的任何边缘,这些边缘由给定的间隔分开,这意味着在示例2中连接节点10和20的边缘被删除,因为间隔[13,15]是“在它们之间”因此将它们分开 .

对于空的树序列,Merge()给出一个空的BST . 对于包含树T的单元素序列,Merge(T)= T.对于树序列T1,...,Tk,其中k> 1,设A1 <A2 <... <An是键序列存储在所有树T1,...,Tk的并集中,按升序排序 . 设m =⌊(1 k)/2⌋,Ts为包含Am的树 . 然后,Merge(T1,...,Tk)给出通过合并三个树Ts,TL = Merge(T1,...,Ts-1)和TR = Merge(Ts 1,...,Tk)创建的树T. ) . 通过 Build 以下两个链接来合并这些树:TL被附加为存储Ts的最小密钥的节点的左子树,并且TR被附加为存储Ts的最大密钥的节点的右子树 .

在我这样做之后,我的任务是找到生成的合并树的深度D和深度为D-1的节点数 . 即使对于100000个节点的树(第4个例子),我的程序也应在几秒钟内完成 .

我的问题是,我还没有得到如何做到这一点或甚至开始的线索 . 我设法在删除之前构造所需的树,但那是关于那个 . 我很感激能够实施一个解决这个问题或任何建议的计划 . 优选地,在一些C-ish编程语言中 .

例子:

输入(第一个数字是要插入空树中的键数,第二个是按给定顺序插入的唯一键,第三行包含两个数字,表示要删除的间隔):

13    
10 5 8 6 9 7 20 15 22 13 17 16 18  
8 16

正确的程序输出: 3 3 ,第一个数字是深度D,第二个节点数量是深度D-1

输入:

13
10 5 8 6 9 7 20 15 22 13 17 16 18
13 15

正确输出: 4 3

pictures of the two examples

示例3:https://justpaste.it/1du6l正确输出: 13 6

示例4:link正确输出: 58 9

1 回答

  • 2

    这是一个很大的答案,我将在高层讨论 . 请查看来源以获取详细信息,或者在评论中提出澄清说明 .

    Global Variables

    • vector<Node*> roots :存储所有新树的根 .

    • map<Node*,int> smap :对于每个新树,存储它的大小

    • vector<int> prefixroots 向量的前缀和,用于在 merge 中轻松进行二进制搜索

    Functions

    • inorder :查找BST的大小(所有呼叫合并为O(N))

    • delInterval :主题是,如果root不是_439574,那么孩子可能是新树的根源 . 最后两个 if 检查编辑中的特殊边缘 . 为每个节点(后期订单)执行此操作 . (上))

    • merge :将位于 start 的所有新根合并到 roots 中的 end 索引 . 首先,我们在 total 中找到新树的总成员(使用 roots 的前缀 - 总和,即 prefix ) . mid 在您的问题中表示 m . ind 是包含 mid -th节点的根索引,我们在 root 变量中检索它 . 现在递归地构建左/右子树并将它们添加到左/右节点中 . O(N)复杂性 .

    • traverse :在 level map中,计算每个树深度的节点数 . (O(N.logN),unordered_map将其变为O(N))

    现在的代码(不要惊慌!!!):

    #include <bits/stdc++.h>
    using namespace std;
    
    int N = 12;
    
    struct Node
    {
        Node* parent=NULL,*left=NULL,*right = NULL;
        int value;
        Node(int x,Node* par=NULL) {value = x;parent = par;}
    };
    
    void insert(Node* root,int x){
        if(x<root->value){
            if(root->left) insert(root->left,x);
            else root->left = new Node(x,root);
        }
        else{
            if(root->right) insert(root->right,x);
            else root->right = new Node(x,root);
        }
    }
    
    int inorder(Node* root){
        if(root==NULL) return 0;
        int l = inorder(root->left);
        return l+1+inorder(root->right);
    }
    
    vector<Node*> roots;
    map<Node*,int> smap;
    vector<int> prefix;
    
    Node* delInterval(Node* root,int x,int y){
        if(root==NULL) return NULL;
        root->left = delInterval(root->left,x,y);
        root->right = delInterval(root->right,x,y);
        if(root->value<=y && root->value>=x){
            if(root->left) roots.push_back(root->left);
            if(root->right) roots.push_back(root->right);
            return NULL;
        }
        if(root->value<x && root->right && root->right->value>y) {
            roots.push_back(root->right);
            root->right = NULL;
        }
        if(root->value>y && root->left && root->left->value<x) {
            roots.push_back(root->left);
            root->left = NULL;
        }
        return root;
    
    }
    Node* merge(int start,int end){
        if(start>end) return NULL;
        if(start==end) return roots[start];
        int total = prefix[end] - (start>0?prefix[start-1]:0);//make sure u get this line
        int mid = (total+1)/2 + (start>0?prefix[start-1]:0); //or this won't make sense
        int ind = lower_bound(prefix.begin(),prefix.end(),mid) - prefix.begin();
        Node* root = roots[ind];
        Node* TL = merge(start,ind-1);
        Node* TR = merge(ind+1,end);
        Node* temp = root;
        while(temp->left) temp = temp->left;
        temp->left = TL;
        temp = root;
        while(temp->right) temp = temp->right;
        temp->right = TR;
        return root;
    }
    
    void traverse(Node* root,int depth,map<int, int>& level){
        if(!root) return;
        level[depth]++;
        traverse(root->left,depth+1,level);
        traverse(root->right,depth+1,level);
    }
    
    int main(){
        srand(time(NULL));
        cin>>N;
        int* arr = new int[N],start,end;
        for(int i=0;i<N;i++) cin>>arr[i];
        cin>>start>>end;
    
        Node* tree = new Node(arr[0]); //Building initial tree
        for(int i=1;i<N;i++) {insert(tree,arr[i]);}
    
        Node* x = delInterval(tree,start,end); //deleting the interval
        if(x) roots.push_back(x);
    
        //sort the disconnected roots, and find their size
        sort(roots.begin(),roots.end(),[](Node* r,Node* v){return r->value<v->value;}); 
        for(auto& r:roots) {smap[r] = inorder(r);}
    
        prefix.resize(roots.size()); //prefix sum root sizes, to cheaply find 'root' in merge
        prefix[0] = smap[roots[0]];
        for(int i=1;i<roots.size();i++) prefix[i]= smap[roots[i]]+prefix[i-1];
    
        Node* root = merge(0,roots.size()-1); //merge all trees
        map<int, int> level; //key=depth, value = no of nodes in depth
        traverse(root,0,level); //find number of nodes in each depth
    
        int depth = level.rbegin()->first; //access last element's key i.e total depth
        int at_depth_1 = level[depth-1]; //no of nodes before
        cout<<depth<<" "<<at_depth_1<<endl; //hoorray
    
        return 0;
    }
    

相关问题