
Multiple subset sum calculation


I have 2 sets: set A contains a collection of random numbers, and each element of set B is the sum of a subset of set A.

For example,

A = [8, 9, 15, 15, 33, 36, 39, 45, 46, 60, 68, 73, 80, 92, 96]

B = [183, 36, 231, 128, 137]

Given data like this, I want to find which subset of A sums to each number in B:

S = [[45, 46, 92], [36], [8, 15, 39, 73, 96], [60, 68], [9, 15, 33, 80]]

I was able to write some really dumb brute-force code in Python.

class SolvedException(BaseException):
    pass

def solve(sums, nums, answer):
    num = nums[-1]

    for i in range(0, len(sums)):
        sumi = sums[i]
        if sumi == 0:
            continue
        elif sumi - num < 0:
            continue
        answer[i].append(num)

        sums[i] = sumi - num

        if len(nums) != 1:
            solve(sums, nums[:-1], answer)
        elif sumi - num == 0:
            raise SolvedException(answer)

        sums[i] = sumi

        answer[i].pop()

try:
    solve(B, A, [list() for i in range(0, len(B))])
except SolvedException as e:
    print(e.args[0])

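This code works for small data, but it would take billions of years on my real data (71 numbers and 10 sums).

I could use a better algorithm or some optimization.

Sorry for my bad English and terrible code.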


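Edit: Sorry, I realized that I did not describe the problem exactly.

Since every element of A is used to make the elements of B, sum(A) == sum(B).

Also, the set S must be a partition of the set A.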
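To make the two constraints in the edit concrete, here is a minimal sketch of a checker for a candidate answer. The function name `is_valid_partition` is made up for illustration; it only verifies a given S and does not search for one.

```python
from collections import Counter

def is_valid_partition(A, B, S):
    """Return True if S splits A into groups whose sums match B in order."""
    if len(S) != len(B):
        return False
    # Every group must add up to the corresponding target in B.
    if any(sum(group) != target for group, target in zip(S, B)):
        return False
    # Together the groups must use each element of A exactly once
    # (compared as multisets, since A may contain duplicates such as 15).
    used = Counter(x for group in S for x in group)
    return used == Counter(A)

A = [8, 9, 15, 15, 33, 36, 39, 45, 46, 60, 68, 73, 80, 92, 96]
B = [183, 36, 231, 128, 137]
S = [[45, 46, 92], [36], [8, 15, 39, 73, 96], [60, 68], [9, 15, 33, 80]]
print(is_valid_partition(A, B, S))  # True
```

If S is a partition of A and every group sum matches, then sum(A) == sum(B) follows automatically.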

2 Answers

  • 9

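    This is known as the subset sum problem, and it is a well-known NP-complete problem. So basically no efficient (polynomial-time) solution is known. See for example https://en.wikipedia.org/wiki/Subset_sum_problem

    However, if your number N is not too large, there is a pseudo-polynomial algorithm using dynamic programming: you read the list A from left to right and keep the list of sums that are feasible and smaller than N. If you know the sums that are feasible for a given A, you can easily get those that are feasible for A + [a]. Hence the dynamic programming. It is typically fast enough for a problem of the size you gave there.

    Here is a quick Python solution: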

    def subsetsum(A, N):
        res = {0 : []}
        for i in A:
            newres = dict(res)
            for v, l in res.items():
                if v+i < N:
                    newres[v+i] = l+[i]
                elif v+i == N:
                    return l+[i]
            res = newres
        return None
    

    Then:

    >>> A = [8, 9, 15, 15, 33, 36, 39, 45, 46, 60, 68, 73, 80, 92, 96]
    >>> subsetsum(A, 183)
    [15, 15, 33, 36, 39, 45]
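As a side note on why the running time depends on N rather than on the number of subsets, here is a small sketch of the same "feasible sums" idea using a Python int as a bitset. This is a variation for illustration, not the code from the answer: it only says which totals are reachable, not which elements produce them, and it assumes all numbers are positive integers. The name `reachable_sums` is made up.

```python
def reachable_sums(A, N):
    """Return the set of totals <= N that some subset of A can produce."""
    bits = 1                    # bit k is set when total k is reachable; 0 always is
    mask = (1 << (N + 1)) - 1   # keep only the bits for totals 0..N
    for a in A:
        # Shifting by a adds a to every total reached so far.
        bits |= (bits << a) & mask
    return {k for k in range(N + 1) if (bits >> k) & 1}

A = [8, 9, 15, 15, 33, 36, 39, 45, 46, 60, 68, 73, 80, 92, 96]
sums = reachable_sums(A, 183)
print(183 in sums)  # True
print(7 in sums)    # False, the smallest element is 8
```

At most N + 1 bits are tracked however many subsets exist, which is where the pseudo-polynomial bound comes from.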
    

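    After the OP's edit:

    Now that I understand your problem correctly, I still think it can be solved efficiently, provided you have an efficient subset sum solver. I would use a divide-and-conquer approach on B: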

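    • Cut B into two approximately equal pieces B1 and B2.

    • Use your subset sum solver to search A for all subsets S whose sum equals sum(B1).

    • For each such S, recursively call solve(S, B1) and solve(A - S, B2).

    • If both calls succeed, you have a solution.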

    However, your (71, 10) problem is out of reach for the dynamic programming solution I suggested.
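The divide-and-conquer steps listed above could be sketched roughly as follows. This is only an illustration under assumptions: all numbers are positive integers, the helper names `subsets_with_sum`, `remove_items` and `solve_dc` are made up, and the function stops at the first partition it finds instead of listing all of them.

```python
from collections import Counter

def subsets_with_sum(A, N):
    """All distinct sub-multisets of A (as sorted tuples) that add up to N."""
    res = {0: {()}}
    for a in sorted(A):
        new = {v: set(s) for v, s in res.items()}
        for v, subsets in res.items():
            if v + a <= N:
                new.setdefault(v + a, set()).update(s + (a,) for s in subsets)
        res = new
    return res.get(N, set())

def remove_items(A, S):
    """Return A with the elements of S removed (multiset difference)."""
    return sorted((Counter(A) - Counter(S)).elements())

def solve_dc(A, B):
    """Return one partition of A matching B (a list of lists), or None."""
    if not B:
        return [] if not A else None
    if len(B) == 1:
        return [sorted(A)] if sum(A) == B[0] else None
    mid = len(B) // 2
    B1, B2 = B[:mid], B[mid:]
    # Try every subset S that could cover the first half of the targets.
    for S in subsets_with_sum(A, sum(B1)):
        left = solve_dc(list(S), B1)
        if left is None:
            continue
        right = solve_dc(remove_items(A, S), B2)
        if right is not None:
            return left + right
    return None

A = [8, 9, 15, 15, 33, 36, 39, 45, 46, 60, 68, 73, 80, 92, 96]
B = [183, 36, 231, 128, 137]
result = solve_dc(A, B)
print(result)
print([sum(group) for group in result] == B)  # True
```

The enumeration of all subsets with a given sum is still exponential in the worst case, so this sketch does not make the (71, 10) instance tractable by itself.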


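    By the way, here is a quick solution to your problem that does not use divide and conquer, but contains a correct adaptation of my dynamic solver to get all solutions: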

    from collections import defaultdict

    class NotFound(BaseException):
        pass

    def subset_all_sums(A, N):
        # res maps each reachable sum to the set of tuples that produce it
        res = defaultdict(set, {0 : {()}})
        for nn, i in enumerate(A):
            # perform a deep copy of res
            newres = defaultdict(set)
            for v, l in res.items():
                newres[v] |= set(l)
            # extend every known subset with i while the sum stays <= N
            for v, l in res.items():
                if v+i <= N:
                    for s in l:
                        newres[v+i].add(s+(i,))
            res = newres
        return res[N]
    
    def list_difference(l1, l2):
        ## Similar to merge: both inputs must be sorted.
        res = []
        i1 = 0; i2 = 0
        while i1 < len(l1) and i2 < len(l2):
            if l1[i1] == l2[i2]:
                i1 += 1
                i2 += 1
            elif l1[i1] < l2[i2]:
                res.append(l1[i1])
                i1 += 1
            else:
                # l2 contains an element that is not in l1
                raise NotFound
        # copy whatever is left of l1
        while i1 < len(l1):
            res.append(l1[i1])
            i1 += 1
        return res
    
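    def solve(A, B):
        assert sum(A) == sum(B)
        if not B:
            return [[]]
        res = []
        # every subset of A that sums to the first target
        ss = subset_all_sums(A, B[0])
        for s in ss:
            rem = list_difference(A, s)
            # solve the remaining targets with the remaining numbers
            for sol in solve(rem, B[1:]):
                res.append([s]+sol)
        return res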
    

    Then:

    >>> solve(A, B)
    [[(15, 33, 39, 96), (36,), (8, 15, 60, 68, 80), (9, 46, 73), (45, 92)],
     [(15, 33, 39, 96), (36,), (8, 9, 15, 46, 73, 80), (60, 68), (45, 92)],
     [(8, 15, 15, 33, 39, 73), (36,), (9, 46, 80, 96), (60, 68), (45, 92)],
     [(15, 15, 73, 80), (36,), (8, 9, 33, 39, 46, 96), (60, 68), (45, 92)],
     [(15, 15, 73, 80), (36,), (9, 39, 45, 46, 92), (60, 68), (8, 33, 96)],
     [(8, 33, 46, 96), (36,), (9, 15, 15, 39, 73, 80), (60, 68), (45, 92)],
     [(8, 33, 46, 96), (36,), (15, 15, 60, 68, 73), (9, 39, 80), (45, 92)],
     [(9, 15, 33, 46, 80), (36,), (8, 15, 39, 73, 96), (60, 68), (45, 92)],
     [(45, 46, 92), (36,), (8, 15, 39, 73, 96), (60, 68), (9, 15, 33, 80)],
     [(45, 46, 92), (36,), (8, 15, 39, 73, 96), (15, 33, 80), (9, 60, 68)],
     [(45, 46, 92), (36,), (15, 15, 60, 68, 73), (9, 39, 80), (8, 33, 96)],
     [(45, 46, 92), (36,), (9, 15, 15, 39, 73, 80), (60, 68), (8, 33, 96)],
     [(9, 46, 60, 68), (36,), (8, 15, 39, 73, 96), (15, 33, 80), (45, 92)]]
    
    >>> %timeit solve(A, B)
    100 loops, best of 3: 10.5 ms per loop
    

    So it is quite fast for a problem of this size, even though nothing has been optimized here.

  • 1

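    A complete solution that computes all the ways to make each total. I use ints as characteristic sets (bitmasks) for speed and memory usage: 19='0b10011' represents [A[0],A[1],A[4]]=[8,9,33] here.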
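For readers unfamiliar with the bitmask encoding, here is a tiny standalone sketch of it. The helper names `encode` and `decode_mask` are made up for illustration and are not part of the answer's code.

```python
A = [8, 9, 15, 15, 33, 36, 39, 45, 46, 60, 68, 73, 80, 92, 96]

def encode(indices):
    """Build an int whose bit i is set for every index i in the subset."""
    mask = 0
    for i in indices:
        mask |= 1 << i
    return mask

def decode_mask(mask, A):
    """Return the elements of A selected by the set bits of mask."""
    return [a for i, a in enumerate(A) if (mask >> i) & 1]

mask = encode([0, 1, 4])
print(mask, bin(mask))       # 19 0b10011
print(decode_mask(mask, A))  # [8, 9, 33]
# Two subsets are disjoint exactly when their masks share no set bit.
print(mask & encode([2, 3]) == 0)  # True
```

Because each position in A gets its own bit, the two 15s at indices 2 and 3 are treated as different elements.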

    A = [8, 9, 15, 15, 33, 36, 39, 45, 46, 60, 68, 73, 80, 92, 96]
    B =[183, 36, 231, 128, 137]
    
    def subsetsum(A,N):
        res=[[0]]+[[] for i in range(N)]
        for i,a in enumerate(A):
            k=1<<i        
            stop=[len(l) for l in res] 
            for shift,l in enumerate(res[:N+1-a]):
                n=a+shift   
                ln=res[n]
                for s in l[:stop[shift]]: ln.append(s+k)
        return res
    
    res = subsetsum(A,max(B))
    solB = [res[b] for b in B]
    exactsol = (1 << len(A)) - 1  # bitmask with every element of A used
    
    def decode(answer):
        return [[A[i] for i,b in enumerate(bin(sol)[::-1]) if b=='1'] for sol in answer] 
    
    def solve(i,currentsol,answer):
        if currentsol==exactsol : print(decode(answer))
        if i==len(B): return
        for sol in solB[i]:
            if not currentsol&sol:
                answer.append(sol)
                solve(i+1,currentsol+sol,answer)
                answer.pop()
    

    Usage:

    solve(0,0,[])
    
    [[9, 46, 60, 68], [36], [8, 15, 39, 73, 96], [15, 33, 80], [45, 92]]
    [[9, 46, 60, 68], [36], [8, 15, 39, 73, 96], [15, 33, 80], [45, 92]]
    [[8, 15, 15, 33, 39, 73], [36], [9, 46, 80, 96], [60, 68], [45, 92]]
    [[9, 15, 33, 46, 80], [36], [8, 15, 39, 73, 96], [60, 68], [45, 92]]
    [[9, 15, 33, 46, 80], [36], [8, 15, 39, 73, 96], [60, 68], [45, 92]]
    [[15, 15, 73, 80], [36], [9, 39, 45, 46, 92], [60, 68], [8, 33, 96]]
    [[15, 15, 73, 80], [36], [8, 9, 33, 39, 46, 96], [60, 68], [45, 92]]
    [[45, 46, 92], [36], [15, 15, 60, 68, 73], [9, 39, 80], [8, 33, 96]]
    [[45, 46, 92], [36], [9, 15, 15, 39, 73, 80], [60, 68], [8, 33, 96]]
    [[45, 46, 92], [36], [8, 15, 39, 73, 96], [60, 68], [9, 15, 33, 80]]
    [[45, 46, 92], [36], [8, 15, 39, 73, 96], [15, 33, 80], [9, 60, 68]]
    [[45, 46, 92], [36], [8, 15, 39, 73, 96], [60, 68], [9, 15, 33, 80]]
    [[45, 46, 92], [36], [8, 15, 39, 73, 96], [15, 33, 80], [9, 60, 68]]
    [[15, 33, 39, 96], [36], [8, 15, 60, 68, 80], [9, 46, 73], [45, 92]]
    [[15, 33, 39, 96], [36], [8, 9, 15, 46, 73, 80], [60, 68], [45, 92]]
    [[15, 33, 39, 96], [36], [8, 15, 60, 68, 80], [9, 46, 73], [45, 92]]
    [[15, 33, 39, 96], [36], [8, 9, 15, 46, 73, 80], [60, 68], [45, 92]]
    [[8, 33, 46, 96], [36], [15, 15, 60, 68, 73], [9, 39, 80], [45, 92]]
    [[8, 33, 46, 96], [36], [9, 15, 15, 39, 73, 80], [60, 68], [45, 92]]
    

    Note that each solution appears twice when the two 15s are not in the same subset.

    It solves this problem, which has a unique solution:
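If the doubled lines are unwanted, one option is to filter them after decoding. The sketch below assumes the decoded solutions are collected in a list instead of being printed; the name `unique_solutions` is made up for illustration.

```python
def unique_solutions(solutions):
    """Drop solutions that differ only in which of two equal values was used."""
    seen = set()
    unique = []
    for sol in solutions:
        # Canonical form: each group sorted, group order kept (it follows B).
        key = tuple(tuple(sorted(group)) for group in sol)
        if key not in seen:
            seen.add(key)
            unique.append(sol)
    return unique

found = [
    [[9, 46, 60, 68], [36], [8, 15, 39, 73, 96], [15, 33, 80], [45, 92]],
    [[9, 46, 60, 68], [36], [8, 15, 39, 73, 96], [15, 33, 80], [45, 92]],
    [[8, 15, 15, 33, 39, 73], [36], [9, 46, 80, 96], [60, 68], [45, 92]],
]
print(len(unique_solutions(found)))  # 2
```

Applied to the 19 lines printed above, this kind of filter would leave the same 13 distinct solutions that the first answer lists.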

    A=[1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, 1008, 1009, 1010, 1011,
       1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023,
       1024, 1025, 1026, 1027, 1028, 1029, 1030, 1031, 1032, 1033, 1034, 1035,
       1036, 1037, 1038, 1039, 1040, 1041, 1042, 1043, 1044, 1045, 1046, 1047, 
       1048, 1049]
    
    B=[5010, 5035, 5060, 5085, 5110, 5135, 5160, 5185, 5210, 5235]
    

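    That takes about one second. Unfortunately, it is not yet optimized enough for the (71, 10) problem.

    Here is another one, in pure dynamic programming spirit: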

    import functools

    @functools.lru_cache(max(B))
    def solutions(n):
        if n==0 : return set({frozenset()}) #{{}}
        if n<0 :  return set()
        sols=set()
        for i,a in enumerate(A):
                for s in solutions(n-a):
                    if i not in s : sols.add(s|{i})
        return sols
    
    def decode(answer): return([[A[i] for i in sol] for sol in answer]) 
    
    def solve(B=B,currentsol=set(),answer=[]):
        if len(currentsol)==len(A) : sols.append(decode(answer))
        if B:
            for sol in solutions(B[0]):
                if set.isdisjoint(currentsol,sol):
                    solve(B[1:],currentsol|sol,answer+[sol]) 
    
    sols=[];solve()
    
