cheep-crator-2/vendor/matrixmultiply/testdefs/testdefs.rs

187 lines
5.1 KiB
Rust

use matrixmultiply::{sgemm, dgemm};
#[cfg(feature="cgemm")]
use matrixmultiply::{cgemm, zgemm, CGemmOption};
// Common code for tests - generic treatment of f32, f64, c32, c64 and Gemm
/// Scalar abstraction used by the tests so that real (`f32`, `f64`) and
/// complex (`c32`, `c64`) element types can be handled by the same generic code.
pub trait Float: Copy + std::fmt::Debug + PartialEq {
    /// The value zero.
    fn zero() -> Self;

    /// The value one.
    fn one() -> Self;

    /// Construct the number `x`.
    fn from(x: i64) -> Self;

    /// Construct the number `x + yi`.
    ///
    /// This default discards `y`, which is the right behaviour for types that
    /// are not complex; complex implementors override it.
    fn from2(x: i64, _y: i64) -> Self {
        // Spelled out so it is obvious this is the trait's own constructor.
        <Self as Float>::from(x)
    }

    /// A NaN value of this type.
    fn nan() -> Self;

    /// Real part of the value; by default the value itself.
    fn real(self) -> Self {
        self
    }

    /// Imaginary part of the value, expressed as a value of `Self`.
    fn imag(self) -> Self;

    /// Whether the value is NaN.
    fn is_nan(self) -> bool;

    /// Whether the implementing type is complex; `false` unless overridden.
    fn is_complex() -> bool {
        false
    }

    /// The difference `self - rhs`.
    fn diff(self, rhs: Self) -> Self;

    /// Absolute value, widened to `f64`.
    fn abs_f64(self) -> f64;

    /// Per-type scale for relative-error comparisons.
    // NOTE(review): how callers apply this scale is not visible in this file.
    fn relative_error_scale() -> f64;
}
// `Float` for single-precision reals: the imaginary part is always zero and
// the inherited defaults for `from2`, `real` and `is_complex` are kept.
impl Float for f32 {
    fn zero() -> Self {
        0.0
    }

    fn one() -> Self {
        1.0
    }

    fn from(x: i64) -> Self {
        x as f32
    }

    fn nan() -> Self {
        // Zero divided by zero evaluates to NaN in floating point arithmetic.
        0.0 / 0.0
    }

    fn imag(self) -> Self {
        0.0
    }

    fn is_nan(self) -> bool {
        // Resolves to the inherent `f32::is_nan`, not back to this method.
        self.is_nan()
    }

    fn diff(self, rhs: Self) -> Self {
        self - rhs
    }

    fn abs_f64(self) -> f64 {
        let magnitude: f32 = self.abs();
        magnitude as f64
    }

    fn relative_error_scale() -> f64 {
        1e-6
    }
}
// `Float` for double-precision reals: the imaginary part is always zero and
// the inherited defaults for `from2`, `real` and `is_complex` are kept.
impl Float for f64 {
    fn zero() -> Self {
        0.0
    }

    fn one() -> Self {
        1.0
    }

    fn from(x: i64) -> Self {
        x as f64
    }

    fn nan() -> Self {
        // Zero divided by zero evaluates to NaN in floating point arithmetic.
        0.0 / 0.0
    }

    fn imag(self) -> Self {
        0.0
    }

    fn is_nan(self) -> bool {
        // Resolves to the inherent `f64::is_nan`, not back to this method.
        self.is_nan()
    }

    fn diff(self, rhs: Self) -> Self {
        self - rhs
    }

    fn abs_f64(self) -> f64 {
        // Already an `f64`, so no widening is needed.
        self.abs()
    }

    fn relative_error_scale() -> f64 {
        1e-12
    }
}
/// Single-precision complex number stored as a two-element array:
/// index 0 holds the real part and index 1 the imaginary part.
///
/// Only available when the `cgemm` feature is enabled.
// The lower-case name follows the primitive float naming (`f32`), which is
// why the `non_camel_case_types` lint is silenced here.
#[allow(non_camel_case_types)]
#[cfg(feature="cgemm")]
pub type c32 = [f32; 2];
/// Double-precision complex number stored as a two-element array:
/// index 0 holds the real part and index 1 the imaginary part.
///
/// Only available when the `cgemm` feature is enabled.
// The lower-case name follows the primitive float naming (`f64`), which is
// why the `non_camel_case_types` lint is silenced here.
#[allow(non_camel_case_types)]
#[cfg(feature="cgemm")]
pub type c64 = [f64; 2];
// `Float` for single-precision complex numbers laid out as `[re, im]`.
// `real` and `imag` return the selected component in the real slot of a new
// complex value whose imaginary slot is zero.
#[cfg(feature="cgemm")]
impl Float for c32 {
    fn zero() -> Self {
        [0.0, 0.0]
    }

    fn one() -> Self {
        [1.0, 0.0]
    }

    fn from(x: i64) -> Self {
        [x as f32, 0.0]
    }

    fn from2(x: i64, y: i64) -> Self {
        [x as f32, y as f32]
    }

    fn nan() -> Self {
        // Zero divided by zero evaluates to NaN; both components are NaN.
        let not_a_number: f32 = 0.0 / 0.0;
        [not_a_number, not_a_number]
    }

    fn real(self) -> Self {
        let [re, _] = self;
        [re, 0.0]
    }

    fn imag(self) -> Self {
        let [_, im] = self;
        [im, 0.0]
    }

    fn is_nan(self) -> bool {
        // NaN as soon as either component is NaN (real part checked first).
        self.iter().any(|part| part.is_nan())
    }

    fn is_complex() -> bool {
        true
    }

    fn diff(self, rhs: Self) -> Self {
        let [lhs_re, lhs_im] = self;
        let [rhs_re, rhs_im] = rhs;
        [lhs_re - rhs_re, lhs_im - rhs_im]
    }

    fn abs_f64(self) -> f64 {
        // Modulus sqrt(re^2 + im^2), evaluated in f32 and widened afterwards.
        let [re, im] = self;
        let norm_squared: f32 = re.powi(2) + im.powi(2);
        norm_squared.sqrt() as f64
    }

    fn relative_error_scale() -> f64 {
        1e-6
    }
}
// `Float` for double-precision complex numbers laid out as `[re, im]`.
// `real` and `imag` return the selected component in the real slot of a new
// complex value whose imaginary slot is zero.
#[cfg(feature="cgemm")]
impl Float for c64 {
    fn zero() -> Self {
        [0.0, 0.0]
    }

    fn one() -> Self {
        [1.0, 0.0]
    }

    fn from(x: i64) -> Self {
        [x as f64, 0.0]
    }

    fn from2(x: i64, y: i64) -> Self {
        [x as f64, y as f64]
    }

    fn nan() -> Self {
        // Zero divided by zero evaluates to NaN; both components are NaN.
        let not_a_number: f64 = 0.0 / 0.0;
        [not_a_number, not_a_number]
    }

    fn real(self) -> Self {
        let [re, _] = self;
        [re, 0.0]
    }

    fn imag(self) -> Self {
        let [_, im] = self;
        [im, 0.0]
    }

    fn is_nan(self) -> bool {
        // NaN as soon as either component is NaN (real part checked first).
        self.iter().any(|part| part.is_nan())
    }

    fn is_complex() -> bool {
        true
    }

    fn diff(self, rhs: Self) -> Self {
        let [lhs_re, lhs_im] = self;
        let [rhs_re, rhs_im] = rhs;
        [lhs_re - rhs_re, lhs_im - rhs_im]
    }

    fn abs_f64(self) -> f64 {
        // Modulus sqrt(re^2 + im^2); already f64, so no widening is needed.
        let [re, im] = self;
        let norm_squared: f64 = re.powi(2) + im.powi(2);
        norm_squared.sqrt()
    }

    fn relative_error_scale() -> f64 {
        1e-12
    }
}
/// Element-type dispatch for the GEMM routines, so generic test code can call
/// `T::gemm(..)` instead of naming `sgemm`/`dgemm`/`cgemm`/`zgemm` directly.
pub trait Gemm: Sized {
    /// Run the GEMM routine that matches `Self`.
    ///
    /// The implementations in this file pass every argument through unchanged
    /// to the corresponding `matrixmultiply` function. Per that crate's
    /// documentation the operation is `C <- alpha * A * B + beta * C`, with
    /// `A` being `m x k`, `B` being `k x n` and `C` being `m x n`; `rsX`/`csX`
    /// are the row and column strides of matrix `X`.
    ///
    /// # Safety
    ///
    /// The raw pointers are handed to the underlying routine as-is, so the
    /// caller must satisfy that routine's safety requirements.
    unsafe fn gemm(
        m: usize,
        k: usize,
        n: usize,
        alpha: Self,
        a: *const Self,
        rsa: isize,
        csa: isize,
        b: *const Self,
        rsb: isize,
        csb: isize,
        beta: Self,
        c: *mut Self,
        rsc: isize,
        csc: isize,
    );
}
// Single-precision real GEMM: a direct pass-through to `sgemm`.
impl Gemm for f32 {
    // Safety: every argument is handed to `sgemm` unchanged, so the caller
    // must uphold that function's contract.
    unsafe fn gemm(
        m: usize,
        k: usize,
        n: usize,
        alpha: Self,
        a: *const Self,
        rsa: isize,
        csa: isize,
        b: *const Self,
        rsb: isize,
        csb: isize,
        beta: Self,
        c: *mut Self,
        rsc: isize,
        csc: isize,
    ) {
        sgemm(m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc)
    }
}
// Double-precision real GEMM: a direct pass-through to `dgemm`.
impl Gemm for f64 {
    // Safety: every argument is handed to `dgemm` unchanged, so the caller
    // must uphold that function's contract.
    unsafe fn gemm(
        m: usize,
        k: usize,
        n: usize,
        alpha: Self,
        a: *const Self,
        rsa: isize,
        csa: isize,
        b: *const Self,
        rsb: isize,
        csb: isize,
        beta: Self,
        c: *mut Self,
        rsc: isize,
        csc: isize,
    ) {
        dgemm(m, k, n, alpha, a, rsa, csa, b, rsb, csb, beta, c, rsc, csc)
    }
}
// Single-precision complex GEMM: a pass-through to `cgemm` with both operand
// options fixed to `CGemmOption::Standard`.
#[cfg(feature="cgemm")]
impl Gemm for c32 {
    // Safety: the dimensions, scalars, pointers and strides are handed to
    // `cgemm` unchanged, so the caller must uphold that function's contract.
    unsafe fn gemm(
        m: usize,
        k: usize,
        n: usize,
        alpha: Self,
        a: *const Self,
        rsa: isize,
        csa: isize,
        b: *const Self,
        rsb: isize,
        csb: isize,
        beta: Self,
        c: *mut Self,
        rsc: isize,
        csc: isize,
    ) {
        // NOTE(review): per the matrixmultiply docs `Standard` means the
        // operand is used as-is (no conjugation) - confirm against the crate.
        let option_a = CGemmOption::Standard;
        let option_b = CGemmOption::Standard;
        cgemm(
            option_a, option_b,
            m, k, n,
            alpha,
            a, rsa, csa,
            b, rsb, csb,
            beta,
            c, rsc, csc,
        )
    }
}
// Double-precision complex GEMM: a pass-through to `zgemm` with both operand
// options fixed to `CGemmOption::Standard`.
#[cfg(feature="cgemm")]
impl Gemm for c64 {
    // Safety: the dimensions, scalars, pointers and strides are handed to
    // `zgemm` unchanged, so the caller must uphold that function's contract.
    unsafe fn gemm(
        m: usize,
        k: usize,
        n: usize,
        alpha: Self,
        a: *const Self,
        rsa: isize,
        csa: isize,
        b: *const Self,
        rsb: isize,
        csb: isize,
        beta: Self,
        c: *mut Self,
        rsc: isize,
        csc: isize,
    ) {
        // NOTE(review): per the matrixmultiply docs `Standard` means the
        // operand is used as-is (no conjugation) - confirm against the crate.
        let option_a = CGemmOption::Standard;
        let option_b = CGemmOption::Standard;
        zgemm(
            option_a, option_b,
            m, k, n,
            alpha,
            a, rsa, csa,
            b, rsb, csb,
            beta,
            c, rsc, csc,
        )
    }
}