MLIR  20.0.0git
ArithToArmSME.cpp
Go to the documentation of this file.
1 //===- ArithToArmSME.cpp - Arith to ArmSME dialect conversion -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
#include "mlir/Conversion/ArithToArmSME/ArithToArmSME.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16 
17 namespace mlir {
18 #define GEN_PASS_DEF_ARITHTOARMSMECONVERSIONPASS
19 #include "mlir/Conversion/Passes.h.inc"
20 } // namespace mlir
21 
22 #define DEBUG_TYPE "arith-to-arm-sme"
23 
24 using namespace mlir;
25 
26 //===----------------------------------------------------------------------===//
27 // Conversion helpers
28 //===----------------------------------------------------------------------===//
29 
30 /// Returns true if 'val' is a splat of zero, false otherwise.
31 static bool isSplatZero(Type elemType, DenseElementsAttr val) {
32  if (llvm::isa<FloatType>(elemType))
33  return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
34  if (llvm::isa<IntegerType>(elemType))
35  return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
36  return false;
37 }
38 
39 namespace {
40 
41 //===----------------------------------------------------------------------===//
42 // ConstantOp
43 //===----------------------------------------------------------------------===//
44 
45 /// Conversion pattern for dense arith.constant.
46 struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
48 
49  LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
50  PatternRewriter &rewriter) const final {
51  auto tileType = dyn_cast<VectorType>(constantOp.getType());
52  if (!tileType || !arm_sme::isValidSMETileVectorType(tileType))
53  return failure();
54 
55  auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
56  if (!denseAttr || !denseAttr.isSplat())
57  return failure();
58 
59  auto tileElementType = tileType.getElementType();
60 
61  // Lower 'arith.constant dense<0>' to 'arm_sme.zero' op.
62  if (isSplatZero(tileElementType, denseAttr)) {
63  rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, tileType);
64  return success();
65  }
66 
67  // Lower non-zero constants to a loop of 'arm_sme.insert_tile_slice'
68  // ops that broadcast the constant to each tile slice.
69  auto loc = constantOp.getLoc();
70 
71  // To fill a tile with a constant, we create a 1-D splat of the constant,
72  // then move that into each tile slice (the largest unit we can set at once,
73  // outside of operations like the outerproduct).
74  VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
75  auto denseAttr1D = DenseElementsAttr::get(
76  tileSliceType, denseAttr.getSplatValue<Attribute>());
77  auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
78 
79  auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
80  auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
81  Value currentTile) {
82  // Create 'arm_sme.insert_tile_slice' to write vector to tile
83  // slice.
84  auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
85  loc, tileType, constantOp1D, currentTile, tileSliceIndex);
86  return nextTile.getResult();
87  };
89  rewriter, loc, initTile, makeLoopBody);
90  rewriter.replaceOp(constantOp, forOp.getResult(0));
91 
92  return success();
93  }
94 };
95 
96 } // namespace
97 
98 //===----------------------------------------------------------------------===//
99 // Pattern population
100 //===----------------------------------------------------------------------===//
101 
103  RewritePatternSet &patterns) {
104  patterns.add<ConstantOpToArmSMELowering>(patterns.getContext());
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // Pass definition
109 //===----------------------------------------------------------------------===//
110 
111 namespace {
112 struct ArithToArmSMEConversionPass final
113  : impl::ArithToArmSMEConversionPassBase<ArithToArmSMEConversionPass> {
114  using impl::ArithToArmSMEConversionPassBase<
115  ArithToArmSMEConversionPass>::ArithToArmSMEConversionPassBase;
116 
117  void runOnOperation() override {
118  RewritePatternSet patterns(&getContext());
120  if (failed(
121  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
122  return signalPassFailure();
123  }
124 };
125 } // namespace
static bool isSplatZero(Type elemType, DenseElementsAttr val)
Returns true if 'val' is a splat of zero, false otherwise.
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
An attribute that represents a reference to a dense vector or tensor object.
std::enable_if_t<!std::is_base_of< Attribute, T >::value||std::is_same< Attribute, T >::value, T > getSplatValue() const
Return the splat value for this attribute.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
This class helps build Operations.
Definition: Builders.h:215
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:488
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:823
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:847
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:317
Builder & dropDim(unsigned pos)
Erase a dim from shape @pos.
Definition: BuiltinTypes.h:342
void populateArithToArmSMEConversionPatterns(RewritePatternSet &patterns)
scf::ForOp createLoopOverTileSlices(PatternRewriter &rewriter, Location loc, Value initTile, std::function< Value(OpBuilder &, Location, Value, Value)> makeLoopBody)
Generates a for loop over ZA tile slices where the induction variable is the tile slice index and eac...
Definition: Utils.cpp:75
bool isValidSMETileVectorType(VectorType vType)
Returns true if vType is a valid vector type for an SME tile or false otherwise.
Definition: Utils.cpp:29
Include the generated interface declarations.
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358