MLIR  19.0.0git
ArithToArmSME.cpp
Go to the documentation of this file.
1 //===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16 
17 namespace mlir {
18 #define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
19 #include "mlir/Conversion/Passes.h.inc"
20 } // namespace mlir
21 
22 #define DEBUG_TYPE "arith-to-arm-sme"
23 
24 using namespace mlir;
25 
26 //===----------------------------------------------------------------------===//
27 // Conversion helpers
28 //===----------------------------------------------------------------------===//
29 
30 /// Returns true if 'val' is a splat of zero, false otherwise.
31 static bool isSplatZero(Type elemType, DenseElementsAttr val) {
32  if (llvm::isa<FloatType>(elemType))
33  return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
34  if (llvm::isa<IntegerType>(elemType))
35  return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
36  return false;
37 }
38 
39 namespace {
40 
41 //===----------------------------------------------------------------------===//
42 // ConstantOp
43 //===----------------------------------------------------------------------===//
44 
45 /// Conversion pattern for dense arith.constant.
46 struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
48 
49  LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
50  PatternRewriter &rewriter) const final {
51  auto tileType = dyn_cast<VectorType>(constantOp.getType());
52  if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
53  return failure();
54 
55  auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
56  if (!denseAttr || !denseAttr.isSplat())
57  return failure();
58 
59  auto tileElementType = tileType.getElementType();
60 
61  // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
62  if (isSplatZero(tileElementType, denseAttr)) {
63  rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
64  return success();
65  }
66 
67  // Lower non-zero constants to a loop of 'arm_sme.move_vector_to_tile_slice'
68  // ops that broadcast the constant to each tile slice.
69  auto loc = constantOp.getLoc();
70 
71  // To fill a tile with a constant, we create a 1-D splat of the constant,
72  // then move that into each tile slice (the largest unit we can set at once,
73  // outside of operations like the outerproduct).
74  VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
75  auto denseAttr1D = DenseElementsAttr::get(
76  tileSliceType, denseAttr.getSplatValue<Attribute>());
77  auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
78 
79  auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
80  auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
81  Value currentTile) {
82  // Create 'arm_sme.move_vector_to_tile_slice' to write vector to tile
83  // slice.
84  auto nextTile = b.create<arm_sme::MoveVectorToTileSliceOp>(
85  loc, tileType, constantOp1D, currentTile, tileSliceIndex);
86  return nextTile.getResult();
87  };
89  rewriter, loc, initTile, makeLoopBody);
90  rewriter.replaceOp(constantOp, forOp.getResult(0));
91 
92  return success();
93  }
94 };
95 
96 } // namespace
97 
98 //===----------------------------------------------------------------------===//
99 // Pattern population
100 //===----------------------------------------------------------------------===//
101 
103  RewritePatternSet &patterns) {
104  patterns.add<ConstantOpToArmSMELowering>(patterns.getContext());
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // Pass definition
109 //===----------------------------------------------------------------------===//
110 
111 namespace {
112 struct ArithToArmSMEConversionPass final
113  : impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> {
114  using impl::ArithToArmSMEConversionPassBase<
115  ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase;
116 
117  void runOnOperation() override {
118  RewritePatternSet patterns(&getContext());
120  if (failed(
121  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
122  return signalPassFailure();
123  }
124 };
125 } // namespace
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:209
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:305
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:330
void populateArithToArmSMEConversionPatterns(RewritePatternSet &patterns)
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index and eac...
Definition: Utils.cpp:75
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition: Utils.cpp:29
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358