@@ -287,54 +287,58 @@ def swap(self):
287
287
for i in pt ])
288
288
return new_bvdd .link (swap2right_bvdd , return_tuples )
289
289
290
- def unapply (self , a_rt_inv , rt = None , rt_inv = None , index = 0 ):
290
+ def unapply (self , b_rt_inv , rt = None , rt_inv = None , index = 0 ):
291
291
new_bvdd = type (self )({})
292
- a_rt_inv = a_rt_inv if a_rt_inv is not None else {}
292
+ b_rt_inv = b_rt_inv if b_rt_inv is not None else {}
293
293
rt = rt if rt is not None else {}
294
294
rt_inv = rt_inv if rt_inv is not None else {}
295
295
s2o = self .get_s2o ()
296
296
for inputs in s2o :
297
297
output = s2o [inputs ]
298
298
if isinstance (output , BVDD ):
299
- output , a_rt_inv , rt , rt_inv = output .unapply (a_rt_inv , rt , rt_inv , index + 1 )
299
+ output , b_rt_inv , rt , rt_inv = output .unapply (b_rt_inv , rt , rt_inv , index + 1 )
300
300
new_bvdd .set (inputs , output )
301
301
else :
302
302
if output not in rt_inv :
303
303
rt_inv [output ] = len (rt_inv ) + 1
304
- if output not in a_rt_inv :
305
- a_rt_inv [output ] = len (a_rt_inv ) + 1
306
- rt [rt_inv [output ]] = a_rt_inv [output ]
304
+ if output not in b_rt_inv :
305
+ b_rt_inv [output ] = len (b_rt_inv ) + 1
306
+ rt [rt_inv [output ]] = b_rt_inv [output ]
307
307
new_bvdd .set (inputs , rt_inv [output ])
308
- return new_bvdd .reduce_SBDD ().reduce_BVDD (index ), a_rt_inv , rt , rt_inv
308
+ return new_bvdd .reduce_SBDD ().reduce_BVDD (index ), b_rt_inv , rt , rt_inv
309
309
310
- def upsample (self , level , a_rt_inv = None , b_cs = None , b_rts = None , index = 0 ):
310
+ def upsample (self , level , a_rt_inv = None , b_rt_inv = None , b_cs = None , b_rts = None , index = 0 ):
311
311
assert level > 0
312
312
a_c = type (self )({})
313
313
a_rt_inv = a_rt_inv if a_rt_inv is not None else {}
314
+ b_rt_inv = b_rt_inv if b_rt_inv is not None else {}
314
315
b_cs = b_cs if b_cs is not None else {}
315
316
b_rts = b_rts if b_rts is not None else {}
316
317
s2o = self .get_s2o ()
317
318
for inputs in s2o :
318
319
output = s2o [inputs ]
319
320
if isinstance (output , BVDD ):
320
321
if index < 2 ** (level - 1 ) - 1 :
321
- output , a_rt_inv , b_cs , b_rts = output .upsample (level ,
322
- a_rt_inv , b_cs , b_rts , index + 1 )
322
+ output , a_rt_inv , b_rt_inv , b_cs , b_rts = output .upsample (level ,
323
+ a_rt_inv , b_rt_inv , b_cs , b_rts , index + 1 )
323
324
a_c .set (inputs , output )
324
325
else :
325
- output , a_rt_inv , rt , rt_inv = output .unapply (a_rt_inv )
326
- if output not in a_rt_inv :
327
- a_rt_inv [output ] = len (a_rt_inv ) + 1
328
- b_cs [a_rt_inv [output ]] = output
329
- b_rts [a_rt_inv [output ]] = rt
330
- a_c .set (inputs , a_rt_inv [output ])
326
+ output , b_rt_inv , rt , rt_inv = output .unapply (b_rt_inv )
327
+ key = (output , tuple (rt .values ()))
328
+ if key not in a_rt_inv :
329
+ a_rt_inv [key ] = len (a_rt_inv ) + 1
330
+ b_cs [a_rt_inv [key ]] = output
331
+ b_rts [a_rt_inv [key ]] = rt
332
+ a_c .set (inputs , a_rt_inv [key ])
331
333
else :
332
334
if output not in a_rt_inv :
333
335
a_rt_inv [output ] = len (a_rt_inv ) + 1
336
+ if output not in b_rt_inv :
337
+ b_rt_inv [output ] = len (b_rt_inv ) + 1
334
338
b_cs [a_rt_inv [output ]] = BVDD .constant (1 )
335
- b_rts [a_rt_inv [output ]] = {1 :a_rt_inv [output ]}
339
+ b_rts [a_rt_inv [output ]] = {1 :b_rt_inv [output ]}
336
340
a_c .set (inputs , a_rt_inv [output ])
337
- return a_c .reduce_SBDD ().reduce_BVDD (index ), a_rt_inv , b_cs , b_rts
341
+ return a_c .reduce_SBDD ().reduce_BVDD (index ), a_rt_inv , b_rt_inv , b_cs , b_rts
338
342
339
343
def downsample (self , bvdds , return_tuples ):
340
344
# apply to bvdds may reduce to constants
0 commit comments