MLIR 23.0.0git
Utils.cpp
Go to the documentation of this file.
1//===- Utils.cpp - Utilities to support the ArmSME dialect ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for the ArmSME dialect.
10//
11//===----------------------------------------------------------------------===//
12
14
15namespace mlir::arm_sme {
16
17unsigned getSizeInBytes(TypeSize type) {
18 switch (type) {
19 case arm_sme::TypeSize::Byte:
20 return 1;
21 case arm_sme::TypeSize::Half:
22 return 2;
23 case arm_sme::TypeSize::Word:
24 return 4;
25 case arm_sme::TypeSize::Double:
26 return 8;
27 }
28 llvm_unreachable("unknown type size");
29 return 0;
30}
31
33 assert(isValidSMETileElementType(type) && "invalid tile type!");
35}
36
38 return type.isInteger(8) || type.isInteger(16) || type.isInteger(32) ||
39 type.isInteger(64) || type.isInteger(128) || type.isF16() ||
40 type.isBF16() || type.isF32() || type.isF64() || type.isF128();
41}
42
43bool isValidSMETileVectorType(VectorType vType) {
44 if ((vType.getRank() != 2) || !vType.allDimsScalable())
45 return false;
46
47 auto elemType = vType.getElementType();
48 if (!isValidSMETileElementType(elemType))
49 return false;
50
51 unsigned minNumElts = getSMETileSliceMinNumElts(elemType);
52 return vType.getShape() == ArrayRef<int64_t>({minNumElts, minNumElts});
53}
54
55std::optional<ArmSMETileType> getSMETileType(VectorType type) {
56 if (!isValidSMETileVectorType(type))
57 return {};
58 switch (type.getElementTypeBitWidth()) {
59 case 8:
60 return ArmSMETileType::ZAB;
61 case 16:
62 return ArmSMETileType::ZAH;
63 case 32:
64 return ArmSMETileType::ZAS;
65 case 64:
66 return ArmSMETileType::ZAD;
67 case 128:
68 return ArmSMETileType::ZAQ;
69 default:
70 llvm_unreachable("unknown SME tile type");
71 }
72}
73
75 auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op);
76 if (!tileOp)
77 return success(); // Not a tile op (no need to check).
78 auto tileId = tileOp.getTileId();
79 if (!tileId)
80 return success(); // Not having a tile ID (yet) is okay.
81 if (!tileId.getType().isSignlessInteger(32))
82 return tileOp.emitOpError("tile ID should be a 32-bit signless integer");
83 return success();
84}
85
87 PatternRewriter &rewriter, Location loc, Value initTile,
88 std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody) {
89 OpBuilder::InsertionGuard g(rewriter);
90 auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
91 auto minTileSlices = arith::ConstantIndexOp::create(
92 rewriter, loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
93 auto vscale =
94 vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
95 auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
96 auto numTileSlices =
97 arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
98 auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, numTileSlices,
99 step, ValueRange{initTile});
100 rewriter.setInsertionPointToStart(forOp.getBody());
101 Value nextTile =
102 makeLoopBody(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(),
103 /*currentTile=*/forOp.getRegionIterArg(0));
104 scf::YieldOp::create(rewriter, loc, nextTile);
105 return forOp;
106}
107
108bool isMultipleOfSMETileVectorType(VectorType vType) {
109 if (vType.getRank() != 2 || !vType.allDimsScalable())
110 return false;
111
112 auto elementType = vType.getElementType();
113 if (!isValidSMETileElementType(elementType))
114 return false;
115
116 unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
117
118 int64_t vectorRows = vType.getDimSize(0);
119 int64_t vectorCols = vType.getDimSize(1);
120
121 return (vectorRows > minNumElts || vectorCols > minNumElts) &&
122 vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0;
123}
124
125VectorType getSMETileTypeForElement(Type elementType) {
126 unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
127 return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
128}
129
131 FunctionOpInterface function) {
133 function->walk([&](Operation *op) {
134 auto armSMEOp = dyn_cast<arm_sme::ArmSMETileOpInterface>(op);
135 if (armSMEOp && isOpTriviallyDead(armSMEOp))
136 worklist.push_back(armSMEOp);
137 });
138 while (!worklist.empty()) {
139 Operation *op = worklist.pop_back_val();
140 if (!isOpTriviallyDead(op))
141 continue;
142 for (Value value : op->getOperands()) {
143 if (auto armSMEOp = value.getDefiningOp<arm_sme::ArmSMETileOpInterface>())
144 worklist.push_back(armSMEOp);
145 }
146 rewriter.eraseOp(op);
147 }
148}
149
151 return tileOp && tileOp->getNumResults() == 1 &&
152 tileOp->getNumOperands() == 0 && isPure(tileOp);
153}
154
156 for (Value result : tileOp->getResults()) {
158 return true;
159 }
160 return false;
161}
162
164 if (!tileOp)
165 return nullptr;
166 auto isTileOperandType = [](OpOperand &operand) {
167 return arm_sme::isValidSMETileVectorType(operand.get().getType());
168 };
169 assert(llvm::count_if(tileOp->getOpOperands(), isTileOperandType) <= 1 &&
170 "expected at most one tile operand");
171 OpOperand *tileOperand =
172 llvm::find_if(tileOp->getOpOperands(), isTileOperandType);
173 if (tileOperand == tileOp->getOpOperands().end())
174 return nullptr;
175 return tileOperand;
176}
177
178bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB) {
179 // Note: This is <= due to how tile types are numbered in ArmSMEOps.td.
180 return static_cast<unsigned>(typeA) <= static_cast<unsigned>(typeB);
181}
182
183} // namespace mlir::arm_sme
return success()
IndexType getIndexType()
Definition Builders.cpp:55
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isF128() const
Definition Types.cpp:43
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
Definition Utils.cpp:55
void eraseTriviallyDeadTileOps(IRRewriter &rewriter, FunctionOpInterface function)
Erase trivially dead tile ops from a function.
Definition Utils.cpp:130
VectorType getSMETileTypeForElement(Type elementType)
Creates a vector type for the SME tile of elementType.
Definition Utils.cpp:125
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Definition Utils.cpp:32
bool isValidSMETileElementType(Type type)
Returns true if type is a valid element type for an SME tile or false otherwise.
Definition Utils.cpp:37
bool isMultipleOfSMETileVectorType(VectorType vType)
Returns true if vType is a multiple of an SME tile size.
Definition Utils.cpp:108
bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp is trivially cloneable.
Definition Utils.cpp:150
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index and eac...
Definition Utils.cpp:86
unsigned getSizeInBytes(TypeSize type)
Return the size represented by arm_sme::TypeSize in bytes.
Definition Utils.cpp:17
bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB)
Returns true if typeA is >= typeB (in terms of bytes).
Definition Utils.cpp:178
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition Utils.cpp:43
OpOperand * getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp)
Returns the tile OpOperand for this tileOp (or null).
Definition Utils.cpp:163
LogicalResult verifyOperationHasValidTileId(Operation *)
Verifies the tile ID (if set) on this tile operation is valid.
Definition Utils.cpp:74
bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp produces a tile result.
Definition Utils.cpp:155
constexpr unsigned MinStreamingVectorLengthInBits
Definition Utils.h:33
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., it is speculatable and does not touch memory.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.