MLIR  19.0.0git
LegalizeForLLVMExport.cpp
Go to the documentation of this file.
1 //===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/PatternMatch.h"
19 
20 using namespace mlir;
21 using namespace mlir::arm_sve;
22 
/// Conversion pattern that rewrites an op in place so that it uses the
/// operands supplied by the conversion adaptor, leaving the rest of the op
/// untouched.
///
/// The pattern reports a match failure when the adaptor's operand types are
/// already equal to the op's current operand types, i.e. when there is nothing
/// to forward.
///
/// NOTE(review): listing lines 25 and 27 are absent from this view (presumably
/// the inherited-constructor `using` declaration and the return type of
/// `matchAndRewrite`) -- confirm against the upstream file.
23 template <typename OpTy>
24 class ForwardOperands : public OpConversionPattern<OpTy> {
26 
28  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
29  ConversionPatternRewriter &rewriter) const final {
// Nothing to do if the adaptor operands already have the op's operand types.
30  if (adaptor.getOperands().getTypes() == op->getOperands().getTypes())
31  return rewriter.notifyMatchFailure(op, "operand types already match");
32 
// Swap the operands through the rewriter's in-place modification hook rather
// than mutating the op directly.
33  rewriter.modifyOpInPlace(op,
34  [&]() { op->setOperands(adaptor.getOperands()); });
35  return success();
36  }
37 };
38 
// One-to-one lowerings: each instantiation below pairs a scalable masked
// ArmSVE arithmetic op (first template argument) with the intrinsic op it is
// converted to (second template argument).
//
// NOTE(review): the line that names each alias (presumably
// `using <Op>Lowering =`) is absent from this listing for every entry, as are
// listing lines 39-43 -- confirm against the upstream file.
44  OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
45  ScalableMaskedAddIIntrOp>;
47  OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
48  ScalableMaskedAddFIntrOp>;
50  OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
51  ScalableMaskedSubIIntrOp>;
53  OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
54  ScalableMaskedSubFIntrOp>;
56  OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
57  ScalableMaskedMulIIntrOp>;
59  OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
60  ScalableMaskedMulFIntrOp>;
62  OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
63  ScalableMaskedSDivIIntrOp>;
65  OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
66  ScalableMaskedUDivIIntrOp>;
68  OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
69  ScalableMaskedDivFIntrOp>;
70 
71 namespace {
72 
73 /// Unrolls a conversion to/from equivalent vector types, to allow using a
74 /// conversion intrinsic that only supports 1-D vector types.
75 ///
76 /// Example:
77 /// ```
78 /// %result = arm_sve.convert_to_svbool %source : vector<2x[4]xi1>
79 /// ```
80 /// is rewritten into:
81 /// ```
82 /// %cst = arith.constant dense<false> : vector<2x[16]xi1>
83 /// %1 = vector.extract %source[0] : vector<[4]xi1> from vector<2x[4]xi1>
84 /// %2 = "arm_sve.intr.convert.to.svbool"(%1)
85 /// : (vector<[4]xi1>) -> vector<[16]xi1>
86 /// %3 = vector.insert %2, %cst[0] : vector<[16]xi1> into vector<2x[16]xi1>
87 /// %4 = vector.extract %source[1] : vector<[4]xi1> from vector<2x[4]xi1>
88 /// %5 = "arm_sve.intr.convert.to.svbool"(%4)
89 /// : (vector<[4]xi1>) -> vector<[16]xi1>
90 /// %result = vector.insert %5, %3[1] : vector<[16]xi1> into vector<2x[16]xi1>
91 /// ```
///
/// Template parameters:
///  - `Op`: the conversion op being rewritten; it is replaced by the result of
///    the unrolled sequence.
///  - `IntrOp`: the op created once per trailing-dimension slice of the source.
///
/// NOTE(review): listing lines 94 and 96 are absent from this view (presumably
/// the inherited-constructor `using` declaration and the return type of
/// `matchAndRewrite`) -- confirm against the upstream file.
92 template <typename Op, typename IntrOp>
93 struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> {
95 
// The adaptor parameter is unnamed and unused: the source value is read
// directly from `convertOp`.
97  matchAndRewrite(Op convertOp, typename Op::Adaptor,
98  ConversionPatternRewriter &rewriter) const override {
99  auto loc = convertOp.getLoc();
100 
101  auto source = convertOp.getSource();
102  VectorType sourceType = source.getType();
103  VectorType resultType = convertOp.getResult().getType();
104 
// Seed the result with a zero constant of the full result type; every
// converted slice is inserted into it below.
105  Value result = rewriter.create<arith::ConstantOp>(
106  loc, resultType, rewriter.getZeroAttr(resultType));
107 
108  // We want to iterate over the input vector in steps of the trailing
109  // dimension. So this creates tile shape where all leading dimensions are 1,
110  // and the trailing dimension step is the size of the dimension.
111  SmallVector<int64_t> tileShape(sourceType.getRank(), 1);
112  tileShape.back() = sourceType.getShape().back();
113 
114  // Iterate over all scalable mask/predicate slices of the source vector.
115  for (SmallVector<int64_t> index :
116  StaticTileOffsetRange(sourceType.getShape(), tileShape)) {
// Dropping the trailing entry of the tile offset leaves the position of one
// trailing-dimension slice; the same position is used for the extract and
// for the insert.
117  auto extractOrInsertPosition = ArrayRef(index).drop_back();
118  auto sourceVector = rewriter.create<vector::ExtractOp>(
119  loc, source, extractOrInsertPosition);
// The converted slice keeps the extracted slice's type, except that dim 0 is
// set to the size of the result's trailing dimension.
120  VectorType convertedType =
121  VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType()))
122  .setDim(0, resultType.getShape().back());
123  auto convertedVector =
124  rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector);
// Thread the partially built result through successive inserts.
125  result = rewriter.create<vector::InsertOp>(loc, convertedVector, result,
126  extractOrInsertPosition);
127  }
128 
129  rewriter.replaceOp(convertOp, result);
130  return success();
131  }
132 };
133 
/// Instantiation of SvboolConversionOpLowering that rewrites
/// `ConvertToSvboolOp` by creating one `ConvertToSvboolIntrOp` per
/// trailing-dimension slice of the source.
134 using ConvertToSvboolOpLowering =
135  SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
136 
/// Instantiation of SvboolConversionOpLowering that rewrites
/// `ConvertFromSvboolOp` by creating one `ConvertFromSvboolIntrOp` per
/// trailing-dimension slice of the source.
137 using ConvertFromSvboolOpLowering =
138  SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
139 
142 
143 } // namespace
144 
145 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
///
/// The lowering patterns registered by the visible `patterns.add<...>` call
/// are all constructed with `converter` as their only argument.
///
/// NOTE(review): listing line 146 (the head of the function signature), lines
/// 151-153 (the start of the statement that ends on line 154) and lines
/// 156-167 (the middle of the pattern list) are absent from this view --
/// confirm against the upstream file.
147  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
148  // Populate conversion patterns
149 
150  // clang-format off
154  &converter.getContext());
155  patterns.add<SdotOpLowering,
168  ConvertToSvboolOpLowering,
169  ConvertFromSvboolOpLowering,
170  ZipX2OpLowering,
171  ZipX4OpLowering>(converter);
172  // clang-format on
173 }
174 
// Configures `target` for the ArmSVE legalization: every `*IntrOp` listed
// below is registered as legal, and every corresponding non-intrinsic ArmSVE
// op is registered as illegal.
//
// NOTE(review): listing line 175 (the head of the function signature) is
// absent from this view -- confirm against the upstream file.
176  LLVMConversionTarget &target) {
177  // clang-format off
// Intrinsic ops: allowed to remain after conversion.
178  target.addLegalOp<SdotIntrOp,
179  SmmlaIntrOp,
180  UdotIntrOp,
181  UmmlaIntrOp,
182  ScalableMaskedAddIIntrOp,
183  ScalableMaskedAddFIntrOp,
184  ScalableMaskedSubIIntrOp,
185  ScalableMaskedSubFIntrOp,
186  ScalableMaskedMulIIntrOp,
187  ScalableMaskedMulFIntrOp,
188  ScalableMaskedSDivIIntrOp,
189  ScalableMaskedUDivIIntrOp,
190  ScalableMaskedDivFIntrOp,
191  ConvertToSvboolIntrOp,
192  ConvertFromSvboolIntrOp,
193  ZipX2IntrOp,
194  ZipX4IntrOp>();
// Non-intrinsic ops: registered as illegal, in the same order as the legal
// list above.
195  target.addIllegalOp<SdotOp,
196  SmmlaOp,
197  UdotOp,
198  UmmlaOp,
199  ScalableMaskedAddIOp,
200  ScalableMaskedAddFOp,
201  ScalableMaskedSubIOp,
202  ScalableMaskedSubFOp,
203  ScalableMaskedMulIOp,
204  ScalableMaskedMulFOp,
205  ScalableMaskedSDivIOp,
206  ScalableMaskedUDivIOp,
207  ScalableMaskedDivFOp,
208  ConvertToSvboolOp,
209  ConvertFromSvboolOp,
210  ZipX2Op,
211  ZipX4Op>();
212  // clang-format on
213 }
OneToOneConvertToLLVMPattern< UmmlaOp, UmmlaIntrOp > UmmlaOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSubIOp, ScalableMaskedSubIIntrOp > ScalableMaskedSubIOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSDivIOp, ScalableMaskedSDivIIntrOp > ScalableMaskedSDivIOpLowering
OneToOneConvertToLLVMPattern< SdotOp, SdotIntrOp > SdotOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedAddIOp, ScalableMaskedAddIIntrOp > ScalableMaskedAddIOpLowering
OneToOneConvertToLLVMPattern< SmmlaOp, SmmlaIntrOp > SmmlaOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedSubFOp, ScalableMaskedSubFIntrOp > ScalableMaskedSubFOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedUDivIOp, ScalableMaskedUDivIIntrOp > ScalableMaskedUDivIOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedDivFOp, ScalableMaskedDivFIntrOp > ScalableMaskedDivFOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedMulFOp, ScalableMaskedMulFIntrOp > ScalableMaskedMulFOpLowering
OneToOneConvertToLLVMPattern< UdotOp, UdotIntrOp > UdotOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedAddFOp, ScalableMaskedAddFIntrOp > ScalableMaskedAddFOpLowering
OneToOneConvertToLLVMPattern< ScalableMaskedMulIOp, ScalableMaskedMulIIntrOp > ScalableMaskedMulIOpLowering
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e. this operation is known to not be supported by this target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:143
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:34
MLIRContext & getContext() const
Returns the MLIR context.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
Definition: Pattern.h:198
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
Location getLoc()
The source location the operation was defined or derived from.
Definition: OpDefinition.h:125
This provides public APIs that all operations should have.
type_range getTypes() const
Definition: ValueRange.cpp:26
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
void setOperands(ValueRange operands)
Replace the current operands of this operation with the ones provided in 'operands'.
Definition: Operation.cpp:237
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:718
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
Definition: PatternMatch.h:630
A range-style iterator that allows for iterating over the offsets of all potential tiles of size `tileShape` within the larger shape `shape`, using an ordering specified by `loopOrder`.
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:305
Builder & setDim(unsigned pos, int64_t val)
Set a dim in shape @pos to val.
Definition: BuiltinTypes.h:339
Include the generated interface declarations.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering ArmSVE ops to ops that map to LLVM intrinsics.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateArmSVELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower ArmSVE ops to ops that map to LLVM intrinsics.
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26