我编写了一个父类,在其中定义了所有子类都应具有的一些函数(例如__mul __,__ truediv__等)。这些函数在执行后应保留子类的类型。
这里有一些代码来解释我的意思:
class Magnet():
def __init__(self, length, strength):
self.length = length
self.strength = strength
return
def __mul__(self, other):
if np.isscalar(other):
return Magnet(self.length, self.strength * other)
else:
return NotImplemented
class Quadrupole(Magnet):
def __init__(self, length, strength, name):
super().__init__(length, strength)
self.name = name
return
现在,如果我这样做:
Quad1 = Quadrupole(2, 10, 'Q1')
Quad2 = Quad1 * 2
然后Quad1的类型为'__main __。Quadrupole',Quad2的类型为'__main __。Magnet'。
我想知道如何做到这一点,以便保留孩子的类型,而不将其重铸为父类型。一种解决方案是在子类中重新定义这些功能并进行更改
if np.isscalar(other):
return Magnet(self.length, self.strength * other)
到
if np.isscalar(other):
return Quadrupole(self.length, self.strength * other)
但是进行继承的主要原因是不复制粘贴代码。也许像super()一样,但是向下或类类型的占位符...
感谢您的帮助。
使用
return type(self)(self.length, self.strength * other)
魅力。因为我忘了在Magnet .__ init __()中添加“ name”参数(我的原始代码确实这样做,但是在简化示例时弄乱了),所以会引发错误。
我在这里也发现了相同的问题:在__add__运算符中返回相同子类的对象
您可以使用获取类型type(self)
并从中创建一个新对象。
def __mul__(self, other):
if np.isscalar(other):
return type(self)(self.length, self.strength * other)
raise NotImplemented
(也提出NotImplemented
而不是返回它。)
现在使用您的代码将导致:
return type(self)(self.length, self.strength * other)
TypeError: __init__() missing 1 required positional argument: 'name'
这将需要Quadrupole
具有的默认参数name
。
class Quadrupole(Magnet):
def __init__(self, length, strength, name='unknown'):
super().__init__(length, strength)
self.name = name
而且您的代码很高兴,但您可能不满意。其原因是,你现在失去了对信息name
的的Quadrupole
类。
您将返回该类的新实例,有时这不是必需的,并且您可以更改旧类。这会将您的代码简化为:
def __mul__(self, other):
if np.isscalar(other):
self.strength *= other.strength
return self
raise NotImplemented
这会改变您的旧实例。
解决方案1的主要问题是您丢失了信息,因为您正在创建新的类。现在可能的替代方法是只复制该类。不幸的是,复制一个类并不总是那么简单。
基于这个SO问题,deepcopy
在这种情况下,使用可能会有效,但是如果您具有复杂的类结构,则可能必须实现__copy__
以获得所需的内容。
def __mul__(self, other):
if np.isscalar(other):
class_copy = deepcopy(self)
class_copy.strength *= other
return class_copy
raise NotImplemented
可以选择提供一种__copy__
方法。对于提供的代码段,这不是必需的,但是在更复杂的情况下,则有必要。
def __copy__(self):
return Quadrupole(self.length, self.strength, self.name)
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句