MLIR  20.0.0git
Utils.cpp
Go to the documentation of this file.
1 //===- Utils.cpp - Utilities to support the ArmSME dialect ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the ArmSME dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 
16 namespace mlir::arm_sme {
17 
19  assert(isValidSMETileElementType(type) && "invalid tile type!");
21 }
22 
24  return type.isInteger(8) || type.isInteger(16) || type.isInteger(32) ||
25  type.isInteger(64) || type.isInteger(128) || type.isF16() ||
26  type.isBF16() || type.isF32() || type.isF64() || type.isF128();
27 }
28 
29 bool isValidSMETileVectorType(VectorType vType) {
30  if ((vType.getRank() != 2) || !vType.allDimsScalable())
31  return false;
32 
33  auto elemType = vType.getElementType();
34  if (!isValidSMETileElementType(elemType))
35  return false;
36 
37  unsigned minNumElts = getSMETileSliceMinNumElts(elemType);
38  if (vType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
39  return false;
40 
41  return true;
42 }
43 
44 std::optional<ArmSMETileType> getSMETileType(VectorType type) {
45  if (!isValidSMETileVectorType(type))
46  return {};
47  switch (type.getElementTypeBitWidth()) {
48  case 8:
49  return ArmSMETileType::ZAB;
50  case 16:
51  return ArmSMETileType::ZAH;
52  case 32:
53  return ArmSMETileType::ZAS;
54  case 64:
55  return ArmSMETileType::ZAD;
56  case 128:
57  return ArmSMETileType::ZAQ;
58  default:
59  llvm_unreachable("unknown SME tile type");
60  }
61 }
62 
64  auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op);
65  if (!tileOp)
66  return success(); // Not a tile op (no need to check).
67  auto tileId = tileOp.getTileId();
68  if (!tileId)
69  return success(); // Not having a tile ID (yet) is okay.
70  if (!tileId.getType().isSignlessInteger(32))
71  return tileOp.emitOpError("tile ID should be a 32-bit signless integer");
72  return success();
73 }
74 
76  PatternRewriter &rewriter, Location loc, Value initTile,
77  std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody) {
78  OpBuilder::InsertionGuard g(rewriter);
79  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
80  auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
81  loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
82  auto vscale =
83  rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
84  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
85  auto numTileSlices =
86  rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
87  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
88  ValueRange{initTile});
89  rewriter.setInsertionPointToStart(forOp.getBody());
90  Value nextTile =
91  makeLoopBody(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(),
92  /*currentTile=*/forOp.getRegionIterArg(0));
93  rewriter.create<scf::YieldOp>(loc, nextTile);
94  return forOp;
95 }
96 
97 bool isMultipleOfSMETileVectorType(VectorType vType) {
98  if (vType.getRank() != 2 || !vType.allDimsScalable())
99  return false;
100 
101  auto elementType = vType.getElementType();
102  if (!isValidSMETileElementType(elementType))
103  return false;
104 
105  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
106 
107  int64_t vectorRows = vType.getDimSize(0);
108  int64_t vectorCols = vType.getDimSize(1);
109 
110  return (vectorRows > minNumElts || vectorCols > minNumElts) &&
111  vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0;
112 }
113 
114 VectorType getSMETileTypeForElement(Type elementType) {
115  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
116  return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
117 }
118 
120  FunctionOpInterface function) {
121  SmallVector<Operation *> worklist;
122  function->walk([&](Operation *op) {
123  auto armSMEOp = dyn_cast<arm_sme::ArmSMETileOpInterface>(op);
124  if (armSMEOp && isOpTriviallyDead(armSMEOp))
125  worklist.push_back(armSMEOp);
126  });
127  while (!worklist.empty()) {
128  Operation *op = worklist.pop_back_val();
129  if (!isOpTriviallyDead(op))
130  continue;
131  for (Value value : op->getOperands()) {
132  if (auto armSMEOp = value.getDefiningOp<arm_sme::ArmSMETileOpInterface>())
133  worklist.push_back(armSMEOp);
134  }
135  rewriter.eraseOp(op);
136  }
137 }
138 
139 bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp) {
140  return tileOp && tileOp->getNumResults() == 1 &&
141  tileOp->getNumOperands() == 0 && isPure(tileOp);
142 }
143 
144 bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp) {
145  for (Value result : tileOp->getResults()) {
146  if (arm_sme::isValidSMETileVectorType(result.getType()))
147  return true;
148  }
149  return false;
150 }
151 
152 OpOperand *getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp) {
153  if (!tileOp)
154  return nullptr;
155  auto isTileOperandType = [](OpOperand &operand) {
156  return arm_sme::isValidSMETileVectorType(operand.get().getType());
157  };
158  assert(llvm::count_if(tileOp->getOpOperands(), isTileOperandType) <= 1 &&
159  "expected at most one tile operand");
160  OpOperand *tileOperand =
161  llvm::find_if(tileOp->getOpOperands(), isTileOperandType);
162  if (tileOperand == tileOp->getOpOperands().end())
163  return nullptr;
164  return tileOperand;
165 }
166 
167 bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB) {
168  // Note: This is <= due to how tile types are numbered in ArmSMEOps.td.
169  return static_cast<unsigned>(typeA) <= static_cast<unsigned>(typeB);
170 }
171 
172 } // namespace mlir::arm_sme
IndexType getIndexType()
Definition: Builders.cpp:83
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:353
This class helps build Operations.
Definition: Builders.h:212
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:436
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:476
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:55
bool isF32() const
Definition: Types.cpp:54
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:61
bool isF128() const
Definition: Types.cpp:57
bool isF16() const
Definition: Types.cpp:52
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:128
bool isBF16() const
Definition: Types.cpp:51
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:93
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
Definition: Utils.cpp:44
void eraseTriviallyDeadTileOps(IRRewriter &rewriter, FunctionOpInterface function)
Erase trivially dead tile ops from a function.
Definition: Utils.cpp:119
VectorType getSMETileTypeForElement(Type elementType)
Creates a vector type for the SME tile of elementType.
Definition: Utils.cpp:114
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Definition: Utils.cpp:18
bool isValidSMETileElementType(Type type)
Returns true if type is a valid element type for an SME tile or false otherwise.
Definition: Utils.cpp:23
bool isMultipleOfSMETileVectorType(VectorType vType)
Returns true if vType is a multiple of an SME tile size.
Definition: Utils.cpp:97
bool isTriviallyCloneableTileOp(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp is trivially cloneable.
Definition: Utils.cpp:139
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index and eac...
Definition: Utils.cpp:75
bool isTileTypeGreaterOrEqual(ArmSMETileType typeA, ArmSMETileType typeB)
Returns true if typeA is >= typeB (in terms of bytes).
Definition: Utils.cpp:167
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition: Utils.cpp:29
OpOperand * getTileOpOperand(arm_sme::ArmSMETileOpInterface tileOp)
Returns the tile OpOperand for this tileOp (or null).
Definition: Utils.cpp:152
LogicalResult verifyOperationHasValidTileId(Operation *)
Verifies the tile ID (if set) on this tile operation is valid.
Definition: Utils.cpp:63
bool hasTileResult(arm_sme::ArmSMETileOpInterface tileOp)
Returns true if tileOp produces a tile result.
Definition: Utils.cpp:144
constexpr unsigned MinStreamingVectorLengthInBits
Definition: Utils.h:33
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., it is speculatable and does not touch memory.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...