MLIR  22.0.0git
Utils.cpp
Go to the documentation of this file.
1 //===- Utils.cpp - Utilities to support the ArmSME dialect ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the ArmSME dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
14 
15 namespace mlir::arm_sme {
16 
17 unsigned getSizeInBytes(TypeSize type) {
18  switch (type) {
19  case arm_sme::TypeSize::Byte:
20  return 1;
21  case arm_sme::TypeSize::Half:
22  return 2;
23  case arm_sme::TypeSize::Word:
24  return 4;
25  case arm_sme::TypeSize::Double:
26  return 8;
27  }
28  llvm_unreachable("unknown type size");
29  return 0;
30 }
31 
33  assert(isValidSMETileElementType(type) && "invalid tile type!");
35 }
36 
38  return type.isInteger(8) || type.isInteger(16) || type.isInteger(32) ||
39  type.isInteger(64) || type.isInteger(128) || type.isF16() ||
40  type.isBF16() || type.isF32() || type.isF64() || type.isF128();
41 }
42 
43 bool isValidSMETileVectorType(VectorType vType) {
44  if ((vType.getRank() != 2) || !vType.allDimsScalable())
45  return false;
46 
47  auto elemType = vType.getElementType();
48  if (!isValidSMETileElementType(elemType))
49  return false;
50 
51  unsigned minNumElts = getSMETileSliceMinNumElts(elemType);
52  if (vType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
53  return false;
54 
55  return true;
56 }
57 
58 std::optional<ArmSMETileType> getSMETileType(VectorType type) {
59  if (!isValidSMETileVectorType(type))
60  return {};
61  switch (type.getElementTypeBitWidth()) {
62  case 8:
63  return ArmSMETileType::ZAB;
64  case 16:
65  return ArmSMETileType::ZAH;
66  case 32:
67  return ArmSMETileType::ZAS;
68  case 64:
69  return ArmSMETileType::ZAD;
70  case 128:
71  return ArmSMETileType::ZAQ;
72  default:
73  llvm_unreachable("unknown SME tile type");
74  }
75 }
76 
78  auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op);
79  if (!tileOp)
80  return success(); // Not a tile op (no need to check).
81  auto tileId = tileOp.getTileId();
82  if (!tileId)
83  return success(); // Not having a tile ID (yet) is okay.
84  if (!tileId.getType().isSignlessInteger(32))
85  return tileOp.emitOpError("tile ID should be a 32-bit signless integer");
86  return success();
87 }
88 
90  PatternRewriter &rewriter, Location loc, Value initTile,
91  std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody) {
92  OpBuilder::InsertionGuard g(rewriter);
93  auto step = arith::ConstantIndexOp::create(rewriter, loc, 1);
94  auto minTileSlices = arith::ConstantIndexOp::create(
95  rewriter, loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
96  auto vscale =
97  vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType());
98  auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0);
99  auto numTileSlices =
100  arith::MulIOp::create(rewriter, loc, minTileSlices, vscale);
101  auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, numTileSlices,
102  step, ValueRange{initTile});
103  rewriter.setInsertionPointToStart(forOp.getBody());
104  Value nextTile =
105  makeLoopBody(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(),
106  /*currentTile=*/forOp.getRegionIterArg(0));
107  scf::YieldOp::create(rewriter, loc, nextTile);
108  return forOp;
109 }
110 
111 bool isMultipleOfSMETileVectorType(VectorType vType) {
112  if (vType.getRank() != 2 || !vType.allDimsScalable())
113  return false;
114 
115  auto elementType = vType.getElementType();
116  if (!isValidSMETileElementType(elementType))
117  return false;
118 
119  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
120 
121  int64_t vectorRows = vType.getDimSize(0);
122  int64_t vectorCols = vType.getDimSize(1);
123 
124  return (vectorRows > minNumElts || vectorCols > minNumElts) &&
125  vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0;
126 }
127 
128 VectorType getSMETileTypeForElement(Type elementType) {
129  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
130  return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
131 }
132 
134  FunctionOpInterface function) {
135  SmallVector<Operation *> worklist;
136  function->walk([&](Operation *op) {
137  auto armSMEOp = dyn_cast<arm_sme::ArmSMETileOpInterface>(op);
138  if (armSMEOp && isOpTriviallyDead(armSMEOp))
139  worklist.push_back(armSMEOp);
140  });
141  while (!worklist.empty()) {
142  Operation *op = worklist.pop_back_val();
143  if (!isOpTriviallyDead(op))
144  continue;
145  for (Value value : op->getOperands()) {
146  if (auto armSMEOp = value.getDefiningOp<arm_sme::ArmSMETileOpInterface>())
147  worklist.push_back(armSMEOp);
148  }
149  rewriter.eraseOp(op);
150  }
151 }
152 
153 bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp) {
154  return tileOp && tileOp->getNumResults() == 1 &&
155  tileOp->getNumOperands() == 0 && isPure(tileOp);
156 }
157 
158 bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp) {
159  for (Value result : tileOp->getResults()) {
160  if (arm_sme::isValidSMETileVectorType(result.getType()))
161  return true;
162  }
163  return false;
164 }
165 
166 OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp) {
167  if (!tileOp)
168  return nullptr;
169  auto isTileOperandType = [](OpOperand &operand) {
170  return arm_sme::isValidSMETileVectorType(operand.get().getType());
171  };
172  assert(llvm::count_if(tileOp->getOpOperands(), isTileOperandType) <= 1 &&
173  "expected at most one tile operand");
174  OpOperand *tileOperand =
175  llvm::find_if(tileOp->getOpOperands(), isTileOperandType);
176  if (tileOperand == tileOp->getOpOperands().end())
177  return nullptr;
178  return tileOperand;
179 }
180 
181 bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB) {
182  // Note: This is <= due to how tile types are numbered in ArmSMEOps.td.
183  return static_cast<unsigned>(typeA) <= static_cast<unsigned>(typeB);
184 }
185 
186 } // namespace mlir::arm_sme
IndexType getIndexType()
Definition: Builders.cpp:50
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:774
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
This class represents an operand of an operation.
Definition: Value.h:257
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:41
bool isF32() const
Definition: Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
bool isF128() const
Definition: Types.cpp:43
bool isF16() const
Definition: Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
bool isBF16() const
Definition: Types.cpp:37
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition: ArithOps.cpp:359
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
Definition: Utils.cpp:58
void eraseTriviallyDeadTileOps(IRRewriter &rewriter, FunctionOpInterface function)
Erase trivially dead tile ops from a function.
Definition: Utils.cpp:133
VectorType getSMETileTypeForElement(Type elementType)
Creates a vector type for the SME tile of elementType.
Definition: Utils.cpp:128
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Definition: Utils.cpp:32
bool isValidSMETileElementType(Type type)
Returns true if type is a valid element type for an SME tile or false otherwise.
Definition: Utils.cpp:37
bool isMultipleOfSMETileVectorType(VectorType vType)
Returns true if vType is a multiple of an SME tile size (returns false if vType is exactly one SME tile).
Definition: Utils.cpp:111
bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp is trivially cloneable.
Definition: Utils.cpp:153
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index and eac...
Definition: Utils.cpp:89
unsigned getSizeInBytes(TypeSize type)
Return the size represented by arm_sme::TypeSize in bytes.
Definition: Utils.cpp:17
bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB)
Returns true if typeA is >= typeB (in terms of bytes).
Definition: Utils.cpp:181
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition: Utils.cpp:43
OpOperand * getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp)
Returns the tile OpOperand for this tileOp (or null).
Definition: Utils.cpp:166
LogicalResult verifyOperationHasValidTileId(Operation *)
Verifies the tile ID (if set) on this tile operation is valid.
Definition: Utils.cpp:77
bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp produces a tile result.
Definition: Utils.cpp:158
constexpr unsigned MinStreamingVectorLengthInBits
Definition: Utils.h:33
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable and does not touch memory.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...