MLIR 22.0.0git
Utils.cpp
Go to the documentation of this file.
1//===- Utils.cpp - Utilities to support the ArmSME dialect ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for the ArmSME dialect.
10//
11//===----------------------------------------------------------------------===//
12
14
15namespace mlir::arm_sme {
16
17unsigned getSizeInBytes(TypeSize type) {
18 switch (type) {
19 case arm_sme::TypeSize::Byte:
20 return 1;
21 case arm_sme::TypeSize::Half:
22 return 2;
23 case arm_sme::TypeSize::Word:
24 return 4;
25 case arm_sme::TypeSize::Double:
26 return 8;
27 }
28 llvm_unreachable("unknown type size");
29 return 0;
30}
31
// Tail of getSMETileSliceMinNumElts(Type type), which is described as
// returning the minimum number of elements for the given element type in a
// vector of SVL bits.
// NOTE(review): the signature line and the return statement of this
// definition are not present in this listing; only the precondition check
// below is visible.
// Precondition: `type` must be accepted by isValidSMETileElementType.
33 assert(isValidSMETileElementType(type) && "invalid tile type!");
35}
36
// Body of isValidSMETileElementType(Type type): returns true if `type` is a
// valid element type for an SME tile, false otherwise.
// NOTE(review): the signature line of this definition is not present in this
// listing.
// Accepted element types: integers of width 8, 16, 32, 64 or 128 bits, and
// the floating-point types f16, bf16, f32, f64 and f128.
38 return type.isInteger(8) || type.isInteger(16) || type.isInteger(32) ||
39 type.isInteger(64) || type.isInteger(128) || type.isF16() ||
40 type.isBF16() || type.isF32() || type.isF64() || type.isF128();
41}
42
43bool isValidSMETileVectorType(VectorType vType) {
44 if ((vType.getRank() != 2) || !vType.allDimsScalable())
45 return false;
46
47 auto elemType = vType.getElementType();
48 if (!isValidSMETileElementType(elemType))
49 return false;
50
51 unsigned minNumElts = getSMETileSliceMinNumElts(elemType);
52 if (vType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
53 return false;
54
55 return true;
56}
57
58std::optional<ArmSMETileType> getSMETileType(VectorType type) {
59 if (!isValidSMETileVectorType(type))
60 return {};
61 switch (type.getElementTypeBitWidth()) {
62 case 8:
63 return ArmSMETileType::ZAB;
64 case 16:
65 return ArmSMETileType::ZAH;
66 case 32:
67 return ArmSMETileType::ZAS;
68 case 64:
69 return ArmSMETileType::ZAD;
70 case 128:
71 return ArmSMETileType::ZAQ;
72 default:
73 llvm_unreachable("unknown SME tile type");
74 }
75}
76
// Body of verifyOperationHasValidTileId(Operation *): verifies that the tile
// ID (if set) on this tile operation is valid.
// NOTE(review): the signature line of this definition is not present in this
// listing.
// Succeeds when `op` does not implement ArmSMETileOpInterface or has no tile
// ID; otherwise the tile ID's type must be a 32-bit signless integer, and an
// op error is emitted if it is not.
78 auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op);
79 if (!tileOp)
80 return success(); // Not a tile op (no need to check).
81 auto tileId = tileOp.getTileId();
82 if (!tileId)
83 return success(); // Not having a tile ID (yet) is okay.
84 if (!tileId.getType().isSignlessInteger(32))
85 return tileOp.emitOpError("tile ID should be a 32-bit signless integer");
86 return success();
87}
88
// Remainder of createLoopOverTileSlices: builds an scf.for loop over the
// slices of a tile, where the induction variable is the tile slice index and
// the tile is threaded through the loop as its single iteration argument.
// NOTE(review): the first line of the signature (return type and function
// name) is not present in this listing.
//
// `initTile` is the initial value of the loop-carried tile. `makeLoopBody` is
// called once, with the insertion point at the start of the loop body, and
// receives (builder, loc, tileSliceIndex, currentTile); the value it returns
// is yielded as the next loop-carried tile. Returns the created scf::ForOp.
90 PatternRewriter &rewriter, Location loc, Value initTile,
91 std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody) {
// Restore the rewriter's insertion point when this function returns.
92 OpBuilder::InsertionGuard g(rewriter);
93 auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
// Size of dimension 0 of the tile's vector type, as an index constant.
94 auto minTileSlices = arith::ConstantIndexOp::create(
95 rewriter, loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
96 auto vscale =
97 vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
98 auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
// Upper bound of the loop: minTileSlices * vscale.
99 auto numTileSlices =
100 arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
// Loop from 0 to numTileSlices in steps of 1, carrying the tile.
101 auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, numTileSlices,
102 step, ValueRange{initTile});
103 rewriter.setInsertionPointToStart(forOp.getBody());
104 Value nextTile =
105 makeLoopBody(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(),
106 /*currentTile=*/forOp.getRegionIterArg(0));
// Yield the tile produced by the callback as the next iteration's tile.
107 scf::YieldOp::create(rewriter, loc, nextTile);
108 return forOp;
109}
110
111bool isMultipleOfSMETileVectorType(VectorType vType) {
112 if (vType.getRank() != 2 || !vType.allDimsScalable())
113 return false;
114
115 auto elementType = vType.getElementType();
116 if (!isValidSMETileElementType(elementType))
117 return false;
118
119 unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
120
121 int64_t vectorRows = vType.getDimSize(0);
122 int64_t vectorCols = vType.getDimSize(1);
123
124 return (vectorRows > minNumElts || vectorCols > minNumElts) &&
125 vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0;
126}
127
128VectorType getSMETileTypeForElement(Type elementType) {
129 unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
130 return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
131}
132
// Remainder of eraseTriviallyDeadTileOps(IRRewriter &rewriter,
// FunctionOpInterface function): erases trivially dead tile ops from a
// function.
// NOTE(review): the first signature line and the declaration of `worklist`
// are not present in this listing.
134 FunctionOpInterface function) {
// Seed the worklist with every ArmSME tile op in the function that is
// already trivially dead.
136 function->walk([&](Operation *op) {
137 auto armSMEOp = dyn_cast<arm_sme::ArmSMETileOpInterface>(op);
138 if (armSMEOp && isOpTriviallyDead(armSMEOp))
139 worklist.push_back(armSMEOp);
140 });
141 while (!worklist.empty()) {
142 Operation *op = worklist.pop_back_val();
// Re-check on pop: ops queued as operand producers below are pushed without
// a deadness check.
143 if (!isOpTriviallyDead(op))
144 continue;
// Queue the tile ops that define this op's operands so they are revisited
// after this op has been erased.
145 for (Value value : op->getOperands()) {
146 if (auto armSMEOp = value.getDefiningOp<arm_sme::ArmSMETileOpInterface>())
147 worklist.push_back(armSMEOp);
148 }
149 rewriter.eraseOp(op);
150 }
151}
152
// Body of isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp):
// returns true if `tileOp` is trivially cloneable.
// NOTE(review): the signature line of this definition is not present in this
// listing.
// The op qualifies when it is non-null, has exactly one result, has no
// operands, and is pure.
154 return tileOp && tileOp->getNumResults() == 1 &&
155 tileOp->getNumOperands() == 0 && isPure(tileOp);
156}
157
// Body of hasTileResult(arm_sme::ArmSMETileOpInterface tileOp): returns true
// if `tileOp` produces a tile result.
// NOTE(review): the signature line and the condition guarding `return true`
// inside the loop are not present in this listing; as shown, only the
// iteration over the op's results and the two return statements are visible.
159 for (Value result : tileOp->getResults()) {
161 return true;
162 }
// No result satisfied the check.
163 return false;
164}
165
// Body of getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp): returns
// the tile OpOperand for this `tileOp`, or null.
// NOTE(review): the signature line of this definition is not present in this
// listing.
// Null is returned both when `tileOp` itself is null and when none of its
// operands has a valid SME tile vector type.
167 if (!tileOp)
168 return nullptr;
// Predicate: the operand's value has a valid SME tile vector type.
169 auto isTileOperandType = [](OpOperand &operand) {
170 return arm_sme::isValidSMETileVectorType(operand.get().getType());
171 };
// At most one operand is expected to match the predicate.
172 assert(llvm::count_if(tileOp->getOpOperands(), isTileOperandType) <= 1 &&
173 "expected at most one tile operand");
174 OpOperand *tileOperand =
175 llvm::find_if(tileOp->getOpOperands(), isTileOperandType);
// find_if returns the end position when no operand matched.
176 if (tileOperand == tileOp->getOpOperands().end())
177 return nullptr;
178 return tileOperand;
179}
180
181bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB) {
182 // Note: This is <= due to how tile types are numbered in ArmSMEOps.td.
183 return static_cast<unsigned>(typeA) <= static_cast<unsigned>(typeB);
184}
185
186} // namespace mlir::arm_sme
return success()
IndexType getIndexType()
Definition Builders.cpp:51
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
This class represents an operand of an operation.
Definition Value.h:257
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF64() const
Definition Types.cpp:41
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
bool isF128() const
Definition Types.cpp:43
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
bool isBF16() const
Definition Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
Definition Utils.cpp:58
void eraseTriviallyDeadTileOps(IRRewriter &rewriter, FunctionOpInterface function)
Erase trivially dead tile ops from a function.
Definition Utils.cpp:133
VectorType getSMETileTypeForElement(Type elementType)
Creates a vector type for the SME tile of elementType.
Definition Utils.cpp:128
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Definition Utils.cpp:32
bool isValidSMETileElementType(Type type)
Returns true if type is a valid element type for an SME tile or false otherwise.
Definition Utils.cpp:37
bool isMultipleOfSMETileVectorType(VectorType vType)
Returns true if vType is a multiple of an SME tile size.
Definition Utils.cpp:111
bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp is trivially cloneable.
Definition Utils.cpp:153
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index and eac...
Definition Utils.cpp:89
unsigned getSizeInBytes(TypeSize type)
Return the size represented by arm_sme::TypeSize in bytes.
Definition Utils.cpp:17
bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB)
Returns true if typeA is >= typeB (in terms of bytes).
Definition Utils.cpp:181
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition Utils.cpp:43
OpOperand * getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp)
Returns the tile OpOperand for this tileOp (or null).
Definition Utils.cpp:166
LogicalResult verifyOperationHasValidTileId(Operation *)
Verifies the tile ID (if set) on this tile operation is valid.
Definition Utils.cpp:77
bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp produces a tile result.
Definition Utils.cpp:158
constexpr unsigned MinStreamingVectorLengthInBits
Definition Utils.h:33
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., it is speculatable and does not touch memory.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.