MLIR  19.0.0git
Utils.cpp
Go to the documentation of this file.
1 //===- Utils.cpp - Utilities to support the ArmSME dialect ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the ArmSME dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 
16 namespace mlir::arm_sme {
17 
19  assert(isValidSMETileElementType(type) && "invalid tile type!");
21 }
22 
24  return type.isInteger(8) || type.isInteger(16) || type.isInteger(32) ||
25  type.isInteger(64) || type.isInteger(128) || type.isF16() ||
26  type.isBF16() || type.isF32() || type.isF64() || type.isF128();
27 }
28 
29 bool isValidSMETileVectorType(VectorType vType) {
30  if ((vType.getRank() != 2) || !vType.allDimsScalable())
31  return false;
32 
33  auto elemType = vType.getElementType();
34  if (!isValidSMETileElementType(elemType))
35  return false;
36 
37  unsigned minNumElts = getSMETileSliceMinNumElts(elemType);
38  if (vType.getShape() != ArrayRef<int64_t>({minNumElts, minNumElts}))
39  return false;
40 
41  return true;
42 }
43 
44 std::optional<ArmSMETileType> getSMETileType(VectorType type) {
45  if (!isValidSMETileVectorType(type))
46  return {};
47  switch (type.getElementTypeBitWidth()) {
48  case 8:
49  return ArmSMETileType::ZAB;
50  case 16:
51  return ArmSMETileType::ZAH;
52  case 32:
53  return ArmSMETileType::ZAS;
54  case 64:
55  return ArmSMETileType::ZAD;
56  case 128:
57  return ArmSMETileType::ZAQ;
58  default:
59  llvm_unreachable("unknown SME tile type");
60  }
61 }
62 
64  auto tileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op);
65  if (!tileOp)
66  return success(); // Not a tile op (no need to check).
67  auto tileId = tileOp.getTileId();
68  if (!tileId)
69  return success(); // Not having a tile ID (yet) is okay.
70  if (!tileId.getType().isSignlessInteger(32))
71  return tileOp.emitOpError("tile ID should be a 32-bit signless integer");
72  return success();
73 }
74 
76  PatternRewriter &rewriter, Location loc, Value initTile,
77  std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody) {
78  OpBuilder::InsertionGuard g(rewriter);
79  auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
80  auto minTileSlices = rewriter.create<arith::ConstantIndexOp>(
81  loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0));
82  auto vscale =
83  rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
84  auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
85  auto numTileSlices =
86  rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale);
87  auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step,
88  ValueRange{initTile});
89  rewriter.setInsertionPointToStart(forOp.getBody());
90  Value nextTile =
91  makeLoopBody(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(),
92  /*currentTile=*/forOp.getRegionIterArg(0));
93  rewriter.create<scf::YieldOp>(loc, nextTile);
94  return forOp;
95 }
96 
97 bool isMultipleOfSMETileVectorType(VectorType vType) {
98  if (vType.getRank() != 2 || !vType.allDimsScalable())
99  return false;
100 
101  auto elementType = vType.getElementType();
102  if (!isValidSMETileElementType(elementType))
103  return false;
104 
105  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
106 
107  int64_t vectorRows = vType.getDimSize(0);
108  int64_t vectorCols = vType.getDimSize(1);
109 
110  return (vectorRows > minNumElts || vectorCols > minNumElts) &&
111  vectorRows % minNumElts == 0 && vectorCols % minNumElts == 0;
112 }
113 
114 VectorType getSMETileTypeForElement(Type elementType) {
115  unsigned minNumElts = getSMETileSliceMinNumElts(elementType);
116  return VectorType::get({minNumElts, minNumElts}, elementType, {true, true});
117 }
118 
119 } // namespace mlir::arm_sme
IndexType getIndexType()
Definition: Builders.cpp:71
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:433
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isF64() const
Definition: Types.cpp:52
bool isF32() const
Definition: Types.cpp:51
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:58
bool isF128() const
Definition: Types.cpp:54
bool isF16() const
Definition: Types.cpp:49
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
bool isBF16() const
Definition: Types.cpp:48
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:92
std::optional< ArmSMETileType > getSMETileType(VectorType)
Returns the type of SME tile this vector type corresponds to, or none if the vector type does not fit...
Definition: Utils.cpp:44
VectorType getSMETileTypeForElement(Type elementType)
Creates a vector type for the SME tile of elementType.
Definition: Utils.cpp:114
unsigned getSMETileSliceMinNumElts(Type type)
Return minimum number of elements for the given element type in a vector of SVL bits.
Definition: Utils.cpp:18
bool isValidSMETileElementType(Type type)
Returns true if type is a valid element type for an SME tile or false otherwise.
Definition: Utils.cpp:23
bool isMultipleOfSMETileVectorType(VectorType vType)
Returns true if vType is a multiple of an SME tile size.
Definition: Utils.cpp:97
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index and eac...
Definition: Utils.cpp:75
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition: Utils.cpp:29
LogicalResult verifyOperationHasValidTileId(Operation *)
Verifies the tile ID (if set) on this tile operation is valid.
Definition: Utils.cpp:63
constexpr unsigned MinStreamingVectorLengthInBits
Definition: Utils.h:31
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26