import logging
from collections import namedtuple
import numpy as np
import numpy.typing as npt
from badcrossbar import utils
from badcrossbar.computing import solve
logger = logging.getLogger(__name__)
Interconnect = namedtuple("Interconnect", ["word_line", "bit_line"])
Solution = namedtuple("Solution", ["currents", "voltages"])
Currents = namedtuple("Currents", ["output", "device", "word_line", "bit_line"])
Voltages = namedtuple("Voltages", ["word_line", "bit_line"])
[docs]def solution(
resistances: npt.NDArray,
r_i_word_line: float,
r_i_bit_line: float,
applied_voltages: npt.NDArray,
**kwargs
) -> Solution:
"""Extracts branch currents and node voltages of a crossbar in a
convenient form.
Args:
resistances: Resistances of crossbar devices.
r_i_word_line: Interconnect resistance of the word line segments.
r_i_bit_line: Interconnect resistance of the bit line segments.
applied_voltages: Applied voltages.
**node_voltages: If False, None is returned instead of node voltages.
Returns:
Branch currents and node voltages of the crossbar.
"""
r_i = Interconnect(r_i_word_line, r_i_bit_line)
if r_i.word_line == r_i.bit_line == np.inf:
return insulating_interconnect_solution(resistances, applied_voltages, **kwargs)
v = solve.v(resistances, r_i, applied_voltages)
extracted_voltages = voltages(v, resistances, **kwargs)
extracted_currents = currents(extracted_voltages, resistances, r_i, applied_voltages, **kwargs)
if kwargs.get("node_voltages") is not True:
extracted_voltages = None
extracted_solution = Solution(extracted_currents, extracted_voltages)
return extracted_solution
[docs]def currents(
extracted_voltages: Voltages,
resistances: npt.NDArray,
r_i: Interconnect,
applied_voltages: npt.NDArray,
**kwargs
) -> Currents:
"""Extracts crossbar branch currents in a convenient format.
Args:
extracted_voltages: Crossbar node voltages. It has fields `word_line`
and `bit_line` that contain the potentials at the nodes on the word
and bit lines.
resistances: Resistances of crossbar devices.
r_i: Interconnect resistances along the word and bit line segments.
applied_voltages: Applied voltages.
**all_currents: If False, only output currents are returned, while all
the other ones are set to None.
Returns:
Crossbar branch currents. Named tuple has fields `output`, `device`,
`word_line` and `bit_line` that contain output currents, as well as
currents flowing through the devices and interconnect segments of the
word and bit lines.
"""
device_i = device_currents(extracted_voltages, resistances)
output_i = output_currents(extracted_voltages, device_i, r_i)
if kwargs.get("all_currents"):
word_line_i = word_line_currents(extracted_voltages, device_i, r_i, applied_voltages)
bit_line_i = bit_line_currents(extracted_voltages, device_i, r_i)
logger.info("Extracted currents from all branches in the crossbar.")
else:
device_i = word_line_i = bit_line_i = None
logger.info("Extracted output currents.")
extracted_currents = Currents(output_i, device_i, word_line_i, bit_line_i)
return extracted_currents
[docs]def voltages(v: npt.NDArray, resistances: npt.NDArray, **kwargs) -> Voltages:
"""Extracts crossbar node voltages in a convenient format.
Args:
v: Solution to gv = i in a flattened form.
resistances: Resistances of crossbar devices.
Returns:
Crossbar node voltages. It has fields `word_line` and `bit_line` that
contain the potentials at the nodes on the word and bit lines.
"""
word_line_v = word_line_voltages(v, resistances)
bit_line_v = bit_line_voltages(v, resistances)
extracted_voltages = Voltages(word_line_v, bit_line_v)
if kwargs.get("node_voltages"):
logger.info("Extracted node voltages.")
return extracted_voltages
[docs]def word_line_voltages(v: npt.NDArray, resistances: npt.NDArray) -> npt.NDArray:
"""Extracts voltages at the nodes on the word lines.
Args:
v: Solution to gv = i in a flattened form.
resistances: Resistances of crossbar devices.
Returns:
Voltages at the nodes on the word lines.
"""
v_domain = v[
: resistances.size,
]
return utils.distributed_array(v_domain, resistances)
[docs]def bit_line_voltages(v: npt.NDArray, resistances: npt.NDArray) -> npt.NDArray:
"""Extracts voltages at the nodes on the bit lines.
Args:
v: Solution to gv = i in a flattened form.
resistances: Resistances of crossbar devices.
Returns:
Voltages at the nodes on the bit lines.
"""
v_domain = v[
resistances.size :,
]
return utils.distributed_array(v_domain, resistances)
[docs]def output_currents(
extracted_voltages: Voltages, extracted_device_currents: npt.NDArray, r_i: Interconnect
) -> npt.NDArray:
"""Extracts output currents.
Args:
extracted_voltages: Crossbar node voltages. It has fields `word_line`
and `bit_line` that contain the potentials at the nodes on the word
and bit lines.
extracted_device_currents: Currents flowing through crossbar devices.
r_i: Interconnect resistances along the word and bit line segments.
Returns:
Output currents.
"""
if r_i.bit_line > 0:
output_i = (
extracted_voltages.bit_line[
-1,
]
/ r_i.bit_line
)
else:
output_i = np.sum(extracted_device_currents, axis=0)
output_i = np.transpose(output_i)
if output_i.ndim == 1:
output_i = output_i.reshape(1, output_i.shape[0])
return output_i
[docs]def device_currents(extracted_voltages: Voltages, resistances: npt.NDArray):
"""Extracts currents flowing through crossbar devices.
Args:
extracted_voltages: Crossbar node voltages. It has fields `word_line`
and `bit_line` that contain the potentials at the nodes on the word
and bit lines.
resistances: Resistances of crossbar devices.
Returns:
Currents flowing through crossbar devices.
"""
if extracted_voltages.word_line.ndim > 2:
resistances = np.repeat(
resistances[:, :, np.newaxis], extracted_voltages.word_line.shape[2], axis=2
)
v_diff = extracted_voltages.word_line - extracted_voltages.bit_line
device_i = v_diff / resistances
return device_i
[docs]def word_line_currents(
extracted_voltages: Voltages,
extracted_device_currents: npt.NDArray,
r_i: Interconnect,
applied_voltages: npt.NDArray,
) -> npt.NDArray:
"""Extracts currents flowing through interconnect segments along the word
lines.
Args:
extracted_voltages: Crossbar node voltages. It has fields `word_line`
and `bit_line` that contain the potentials at the nodes on the word
and bit lines.
extracted_device_currents: Currents flowing through crossbar devices.
r_i: Interconnect resistances along the word and bit line segments.
applied_voltages: Applied voltages.
Returns:
Currents flowing through interconnect segments along the word lines.
"""
if r_i.word_line > 0:
word_line_i = np.zeros(extracted_device_currents.shape)
if extracted_voltages.word_line.ndim > 2:
v_diff = (
applied_voltages
- extracted_voltages.word_line[
:,
0,
]
)
word_line_i[:, 0,] = (
v_diff / r_i.word_line
)
else:
v_diff = applied_voltages - extracted_voltages.word_line[:, [0]]
word_line_i[:, [0]] = v_diff / r_i.word_line
v_diff = (
extracted_voltages.word_line[
:,
:-1,
]
- extracted_voltages.word_line[
:,
1:,
]
)
word_line_i[:, 1:,] = (
v_diff / r_i.word_line
)
else:
word_line_i = np.repeat(
extracted_device_currents[
:,
-1:,
],
extracted_device_currents.shape[1],
axis=1,
)
for i in range(1, extracted_device_currents.shape[1]):
word_line_i[:, :-i,] += np.repeat(
extracted_device_currents[
:,
-(1 + i) : -i,
],
extracted_device_currents.shape[1] - i,
axis=1,
)
return word_line_i
[docs]def bit_line_currents(
extracted_voltages: Voltages, extracted_device_currents: npt.NDArray, r_i: Interconnect
) -> npt.NDArray:
"""Extracts currents flowing through interconnect segments along the bit
lines.
Args:
extracted_voltages: Crossbar node voltages. It has fields `word_line`
and `bit_line` that contain the potentials at the nodes on the word
and bit lines.
extracted_device_currents: Currents flowing through crossbar devices.
r_i: Interconnect resistances along the word and bit line segments.
Returns:
Currents flowing through interconnect segments along the bit lines.
"""
if r_i.bit_line > 0:
bit_line_i = np.zeros(extracted_device_currents.shape)
v_diff = (
extracted_voltages.bit_line[
:-1,
:,
]
- extracted_voltages.bit_line[
1:,
:,
]
)
bit_line_i[:-1, :,] = (
v_diff / r_i.bit_line
)
if extracted_voltages.bit_line.ndim > 2:
v_diff = extracted_voltages.bit_line[
-1,
:,
]
bit_line_i[-1, :,] = (
v_diff / r_i.bit_line
)
else:
v_diff = extracted_voltages.bit_line[[-1], :]
bit_line_i[[-1], :] = v_diff / r_i.bit_line
else:
bit_line_i = np.zeros(extracted_device_currents.shape)
for i in range(extracted_device_currents.shape[0]):
bit_line_i[i:, :,] += np.repeat(
extracted_device_currents[
i : i + 1,
:,
],
extracted_device_currents.shape[0] - i,
axis=0,
)
return bit_line_i
[docs]def insulating_interconnect_solution(
resistances: npt.NDArray, applied_voltages: npt.NDArray, **kwargs
) -> Solution:
"""Extracts solution when all interconnects are perfectly insulating.
Args:
resistances: Resistances of crossbar devices.
applied_voltages: Applied voltages.
**all_currents: If False, only output currents are returned, while all
the other ones are set to None.
Returns:
Branch currents and node voltages of the crossbar.
"""
extracted_voltages = Voltages(None, None)
if kwargs.get("node_voltages"):
logger.info(
"Warning: all interconnects are perfectly insulating! Node voltages are undefined!"
)
output_i = np.zeros((applied_voltages.shape[1], resistances.shape[1]))
if kwargs.get("all_currents", True):
same_i = np.zeros((resistances.shape[0], resistances.shape[1], applied_voltages.shape[1]))
same_i = utils.squeeze_third_axis(same_i)
device_i = word_line_i = bit_line_i = same_i
logger.info("Extracted currents from all branches in the crossbar.")
else:
device_i = word_line_i = bit_line_i = None
logger.info("Extracted output currents.")
extracted_currents = Currents(output_i, device_i, word_line_i, bit_line_i)
extracted_solution = Solution(extracted_currents, extracted_voltages)
return extracted_solution