MLIR  22.0.0git
Utils.cpp
Go to the documentation of this file.
1 //===- Utils.cpp - Utilities to support the ArmSME dialect ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the ArmSME dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 namespace mlir::arm_sme {
16 
// Precondition: `type` must satisfy isValidSMETileElementType(). This is only
// enforced by the assert below, i.e. in builds with assertions enabled.
18  assert(isValidSMETileElementType(type) && "invalid tile type!");
20 }
21 
23  return type.isInteger(8) || type.isInteger(16) || type.isInteger(32) ||
24  type.isInteger(64) || type.isInteger(128) || type.isF16() ||
25  type.isBF16() || type.isF32() || type.isF64() || type.isF128();
26 }
27 
28 bool isValidSMETileVectorType(VectorType vType) {
29  if ((vType.getRank() != 2) || !vType.allDimsScalable())
30  return false;
31 
32  auto elemType = vType.getElementType();
33  if (!isValidSMETileElementType(elemType))
34  return false;
35 
36  unsigned minNumElts = getSMETileSliceMinNumElts(elemType);
37  if (vType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
38  return false;
39 
40  return true;
41 }
42 
43 std::optional<ArmSMETileType> getSMETileType(VectorType type) {
44  if (!isValidSMETileVectorType(type))
45  return {};
46  switch (type.getElementTypeBitWidth()) {
47  case 8:
48  return ArmSMETileType::ZAB;
49  case 16:
50  return ArmSMETileType::ZAH;
51  case 32:
52  return ArmSMETileType::ZAS;
53  case 64:
54  return ArmSMETileType::ZAD;
55  case 128:
56  return ArmSMETileType::ZAQ;
57  default:
58  llvm_unreachable("unknown SME tile type");
59  }
60 }
61 
63  auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op);
64  if (!tileOp)
65  return success(); // Not a tile op (no need to check).
66  auto tileId = tileOp.getTileId();
67  if (!tileId)
68  return success(); // Not having a tile ID (yet) is okay.
69  if (!tileId.getType().isSignlessInteger(32))
70  return tileOp.emitOpError("tile ID should be a 32-bit signless integer");
71  return success();
72 }
73 
75  PatternRewriter &rewriter, Location loc, Value initTile,
76  std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody) {
77  OpBuilder::InsertionGuard g(rewriter);
78  auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
79  auto minTileSlices = arith::ConstantIndexOp::create(
80  rewriter, loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
81  auto vscale =
82  vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
83  auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
84  auto numTileSlices =
85  arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
86  auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, numTileSlices,
87  step, ValueRange{initTile});
88  rewriter.setInsertionPointToStart(forOp.getBody());
89  Value nextTile =
90  makeLoopBody(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(),
91  /*currentTile=*/forOp.getRegionIterArg(0));
92  scf::YieldOp::create(rewriter, loc, nextTile);
93  return forOp;
94 }
95 
96 bool isMultipleOfSMETileVectorType(VectorType vType) {
97  if (vType.getRank() != 2 || !vType.allDimsScalable())
98  return false;
99 
100  auto elementType = vType.getElementType();
101  if (!isValidSMETileElementType(elementType))
102  return false;
103 
104  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
105 
106  int64_t vectorRows = vType.getDimSize(0);
107  int64_t vectorCols = vType.getDimSize(1);
108 
109  return (vectorRows > minNumElts || vectorCols > minNumElts) &&
110  vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0;
111 }
112 
113 VectorType getSMETileTypeForElement(Type elementType) {
114  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
115  return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
116 }
117 
119  FunctionOpInterface function) {
120  SmallVector<Operation *> worklist;
121  function->walk([&](Operation *op) {
122  auto armSMEOp = dyn_cast<arm_sme::ArmSMETileOpInterface>(op);
123  if (armSMEOp && isOpTriviallyDead(armSMEOp))
124  worklist.push_back(armSMEOp);
125  });
126  while (!worklist.empty()) {
127  Operation *op = worklist.pop_back_val();
128  if (!isOpTriviallyDead(op))
129  continue;
130  for (Value value : op->getOperands()) {
131  if (auto armSMEOp = value.getDefiningOp<arm_sme::ArmSMETileOpInterface>())
132  worklist.push_back(armSMEOp);
133  }
134  rewriter.eraseOp(op);
135  }
136 }
137 
138 bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp) {
139  return tileOp && tileOp->getNumResults() == 1 &&
140  tileOp->getNumOperands() == 0 && isPure(tileOp);
141 }
142 
143 bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp) {
144  for (Value result : tileOp->getResults()) {
145  if (arm_sme::isValidSMETileVectorType(result.getType()))
146  return true;
147  }
148  return false;
149 }
150 
151 OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp) {
152  if (!tileOp)
153  return nullptr;
154  auto isTileOperandType = [](OpOperand &operand) {
155  return arm_sme::isValidSMETileVectorType(operand.get().getType());
156  };
157  assert(llvm::count_if(tileOp->getOpOperands(), isTileOperandType) <= 1 &&
158  "expected at most one tile operand");
159  OpOperand *tileOperand =
160  llvm::find_if(tileOp->getOpOperands(), isTileOperandType);
161  if (tileOperand == tileOp->getOpOperands().end())
162  return nullptr;
163  return tileOperand;
164 }
165 
166 bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB) {
167  // Note: This is <= due to how tile types are numbered in ArmSMEOps.td.
168  return static_cast<unsigned>(typeA) <= static_cast<unsigned>(typeB);
169 }
170 
171 } // namespace mlir::arm_sme
IndexType getIndexType()
Definition: Builders.cpp:50
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:764
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
This class represents an operand of an operation.
Definition: Value.h:257
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:783
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:41
bool isF32() const
Definition: Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isF128() const
Definition: Types.cpp:43
bool isF16() const
Definition: Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
bool isBF16() const
Definition: Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
Definition: Utils.cpp:43
void eraseTriviallyDeadTileOps(IRRewriter &rewriter, FunctionOpInterface function)
Erase trivially dead tile ops from a function.
Definition: Utils.cpp:118
VectorType getSMETileTypeForElement(Type elementType)
Creates a vector type for the SME tile of elementType.
Definition: Utils.cpp:113
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Definition: Utils.cpp:17
bool isValidSMETileElementType(Type type)
Returns true if type is a valid element type for an SME tile or false otherwise.
Definition: Utils.cpp:22
bool isMultipleOfSMETileVectorType(VectorType vType)
Returns true if vType is a multiple of an SME tile size.
Definition: Utils.cpp:96
bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp is trivially cloneable.
Definition: Utils.cpp:138
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index and eac...
Definition: Utils.cpp:74
bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB)
Returns true if typeA is >= typeB (in terms of bytes).
Definition: Utils.cpp:166
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition: Utils.cpp:28
OpOperand * getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp)
Returns the tile OpOperand for this tileOp (or null).
Definition: Utils.cpp:151
LogicalResult verifyOperationHasValidTileId(Operation *)
Verifies the tile ID (if set) on this tile operation is valid.
Definition: Utils.cpp:62
bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp produces a tile result.
Definition: Utils.cpp:143
constexpr unsigned MinStreamingVectorLengthInBits
Definition: Utils.h:33
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...