48 #include "llvm/ADT/TypeSwitch.h"
50 #define DEBUG_TYPE "allocate-arm-sme-tiles"
54 #define GEN_PASS_DEF_TILEALLOCATION
55 #include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
// Name of the discardable attribute, stored on the enclosing function op, that
// holds the set of tiles already handed out. AssignTileIDsPattern reads it as an
// integer (defaulting to 0 when absent), reinterprets it as a TileMask, and
// writes the updated mask back after an allocation.
64 static constexpr StringLiteral kTilesInUseAttr(
"arm_sme.tiles_in_use");
// Name of the discardable function attribute that AssignTileIDsPattern updates
// to `*tileId + 1` after assigning a tile ID.
// NOTE(review): presumably this is the counter used to hand out IDs for tiles
// that must live in memory; the code that reads it is not visible here --
// confirm against the full file.
65 static constexpr StringLiteral
66 kNextInMemoryTileIdAttr(
"arm_sme.next_in_memory_tile_id");
68 enum class TileMask : unsigned {
109 LLVM_MARK_AS_BITMASK_ENUM(kZA0B)
// Tables of TileMask values, one table per tile granularity. The tables hold
// 1 (B), 2 (H), 4 (S), 8 (D) and 16 (Q) entries respectively.
// NOTE(review): presumably getMasks() returns the table matching an
// ArmSMETileType, and the allocator picks the first entry whose bits do not
// overlap the tiles-in-use mask -- the selecting switch is only partly visible
// here, so confirm against the full file.

// Byte-granularity table: a single mask.
114 static constexpr std::array ZA_B_MASKS = {TileMask::kZA0B};
// Halfword-granularity table: two masks.
115 static constexpr std::array ZA_H_MASKS = {TileMask::kZA0H, TileMask::kZA1H};
// Word-granularity table: four masks.
116 static constexpr std::array ZA_S_MASKS = {TileMask::kZA0S, TileMask::kZA1S,
117 TileMask::kZA2S, TileMask::kZA3S};
// Doubleword-granularity table: eight masks.
118 static constexpr std::array ZA_D_MASKS = {
119 TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
120 TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
// Quadword-granularity table: sixteen masks.
121 static constexpr std::array ZA_Q_MASKS = {
122 TileMask::kZA0Q, TileMask::kZA1Q, TileMask::kZA2Q, TileMask::kZA3Q,
123 TileMask::kZA4Q, TileMask::kZA5Q, TileMask::kZA6Q, TileMask::kZA7Q,
124 TileMask::kZA8Q, TileMask::kZA9Q, TileMask::kZA10Q, TileMask::kZA11Q,
125 TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
127 case ArmSMETileType::ZAB:
129 case ArmSMETileType::ZAH:
131 case ArmSMETileType::ZAS:
133 case ArmSMETileType::ZAD:
135 case ArmSMETileType::ZAQ:
143 TileMask &tilesInUse) {
144 auto masks = getMasks(tileType);
146 if ((tilesInUse & tileMask) == TileMask::kNone) {
147 tilesInUse |= tileMask;
159 static void findDependantOps(
Value rootValue,
161 auto traverseCorrespondingValues = [&](
auto inputValues,
auto exitValues) {
163 if (value == rootValue)
164 findDependantOps(exitValues[idx], dependantOps);
168 if (dependantOps.contains(user))
170 dependantOps.insert(user);
172 .Case<cf::BranchOp>([&](
auto branchOp) {
174 traverseCorrespondingValues(branchOp.getDestOperands(),
175 branchOp.getDest()->getArguments());
177 .Case<cf::CondBranchOp>([&](
auto condBranchOp) {
179 traverseCorrespondingValues(
180 condBranchOp.getTrueOperands(),
181 condBranchOp.getTrueDest()->getArguments());
183 traverseCorrespondingValues(
184 condBranchOp.getFalseOperands(),
185 condBranchOp.getFalseDest()->getArguments());
187 .Case<LoopLikeOpInterface>([&](
auto loopOp) {
189 traverseCorrespondingValues(loopOp.getInits(),
190 loopOp.getRegionIterArgs());
192 .Case<scf::YieldOp>([&](
auto yieldOp) {
194 auto parent = user->getParentOp();
195 traverseCorrespondingValues(user->getOperands(),
196 parent->getResults());
200 for (
Value result : user->getResults())
201 findDependantOps(result, dependantOps);
205 struct AssignTileIDsPattern
210 if (tileOp.getTileId())
213 auto func = tileOp->getParentOfType<FunctionOpInterface>();
214 auto getDiscardableIntAttr = [&](StringRef name,
unsigned defaultVal = 0) {
215 if (
auto attr = llvm::dyn_cast_or_null<IntegerAttr>(
216 func->getDiscardableAttr(name)))
217 return unsigned(attr.getInt());
220 auto setDiscardableIntAttr = [&](StringRef name,
auto value) {
222 func->setDiscardableAttr(name,
227 std::optional<ArmSMETileType> tileType = tileOp.getAllocatedTileType();
231 TileMask tilesInUse =
232 static_cast<TileMask
>(getDiscardableIntAttr(kTilesInUseAttr));
233 auto tileId = allocateTileId(*tileType, tilesInUse);
234 bool tileIsInMemory =
failed(tileId);
235 if (tileIsInMemory) {
241 "failed to allocate SME virtual tile to operation, all tile "
242 "operations will go through memory, expect degraded performance");
261 findDependantOps(tileOp->getResult(0), dependantOps);
263 for (
auto *op : dependantOps) {
264 if (
auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
265 auto currentTileId = dependantTileOp.getTileId();
266 if (currentTileId &&
unsigned(currentTileId.getInt()) != tileId)
267 return dependantTileOp.emitOpError(
268 "already assigned different SME virtual tile!");
274 setDiscardableIntAttr(kTilesInUseAttr, tilesInUse);
276 setDiscardableIntAttr(kNextInMemoryTileIdAttr, *tileId + 1);
277 rewriter.
modifyOpInPlace(tileOp, [&] { tileOp.setTileId(tileIDAttr); });
278 for (
auto *op : dependantOps) {
279 if (
auto dependantTileOp = llvm::dyn_cast<ArmSMETileOpInterface>(op)) {
281 dependantTileOp, [&] { dependantTileOp.setTileId(tileIDAttr); });
289 struct TileAllocationPass
290 :
public arm_sme::impl::TileAllocationBase<TileAllocationPass> {
291 void runOnOperation()
override {
293 patterns.add<AssignTileIDsPattern>(patterns.getContext());
299 getOperation(), std::move(patterns), config))) {
307 return std::make_unique<TileAllocationPass>();
static MLIRContext * getContext(OpFoldResult val)
IntegerAttr getI32IntegerAttr(int32_t value)
This class provides support for representing a failure result, or a valid value of type T.
This class allows control over how the GreedyPatternRewriteDriver works.
bool useTopDownTraversal
This specifies the order of initial traversal that populates the rewriters worklist.
Operation is the basic unit of execution within MLIR.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
user_range getUsers() const
static constexpr unsigned kInMemoryTileIdBase
std::unique_ptr< Pass > createTileAllocationPass()
Pass that allocates tile IDs to ArmSME operations.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist driven manner until a fixpoint is reached.
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of an operation interface instead of a raw Operation.
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit=1)