MLIR  19.0.0git
JointMatrixOps.cpp
Go to the documentation of this file.
1 //===- JointMatrixOps.cpp - MLIR SPIR-V Intel Joint Matrix Ops -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Defines the Intel Joint Matrix operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 namespace mlir {
16 //===----------------------------------------------------------------------===//
17 // spirv.INTEL.JointMatrixLoad
18 //===----------------------------------------------------------------------===//
19 
20 static LogicalResult
21 verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
22  Type pointeeType = llvm::cast<spirv::PointerType>(pointer).getPointeeType();
23  if (!llvm::isa<spirv::ScalarType>(pointeeType) &&
24  !llvm::isa<VectorType>(pointeeType))
25  return op->emitError(
26  "Pointer must point to a scalar or vector type but provided ")
27  << pointeeType;
28  spirv::StorageClass storage =
29  llvm::cast<spirv::PointerType>(pointer).getStorageClass();
30  if (storage != spirv::StorageClass::Workgroup &&
31  storage != spirv::StorageClass::CrossWorkgroup &&
32  storage != spirv::StorageClass::UniformConstant &&
33  storage != spirv::StorageClass::Generic)
34  return op->emitError("Pointer storage class must be Workgroup or "
35  "CrossWorkgroup but provided ")
36  << stringifyStorageClass(storage);
37  return success();
38 }
39 
41  return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
42  getResult().getType());
43 }
44 
45 //===----------------------------------------------------------------------===//
46 // spirv.INTEL.JointMatrixStore
47 //===----------------------------------------------------------------------===//
48 
50  return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
51  getObject().getType());
52 }
53 
54 //===----------------------------------------------------------------------===//
55 // spirv.INTEL.JointMatrixMad
56 //===----------------------------------------------------------------------===//
57 
58 static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
59  if (op.getC().getType() != op.getResult().getType())
60  return op.emitOpError("result and third operand must have the same type");
61  auto typeA = llvm::cast<spirv::JointMatrixINTELType>(op.getA().getType());
62  auto typeB = llvm::cast<spirv::JointMatrixINTELType>(op.getB().getType());
63  auto typeC = llvm::cast<spirv::JointMatrixINTELType>(op.getC().getType());
64  auto typeR =
65  llvm::cast<spirv::JointMatrixINTELType>(op.getResult().getType());
66  if (typeA.getRows() != typeR.getRows() ||
67  typeA.getColumns() != typeB.getRows() ||
68  typeB.getColumns() != typeR.getColumns())
69  return op.emitOpError("matrix size must match");
70  if (typeR.getScope() != typeA.getScope() ||
71  typeR.getScope() != typeB.getScope() ||
72  typeR.getScope() != typeC.getScope())
73  return op.emitOpError("matrix scope must match");
74  if (typeA.getElementType() != typeB.getElementType() ||
75  typeR.getElementType() != typeC.getElementType())
76  return op.emitOpError("matrix element type must match");
77  return success();
78 }
79 
81  return verifyJointMatrixMad(*this);
82 }
83 
84 } // namespace mlir
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
Definition: Operation.cpp:268
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
Type getType() const
Return the type of this value.
Definition: Value.h:129
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
static LogicalResult verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix)
static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op)
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:421
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26