23 template <
typename OpTy>
28 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
30 if (adaptor.getOperands().getTypes() == op->getOperands().getTypes())
34 [&]() { op->setOperands(adaptor.getOperands()); });
45 ScalableMaskedAddIIntrOp>;
48 ScalableMaskedAddFIntrOp>;
51 ScalableMaskedSubIIntrOp>;
54 ScalableMaskedSubFIntrOp>;
57 ScalableMaskedMulIIntrOp>;
60 ScalableMaskedMulFIntrOp>;
63 ScalableMaskedSDivIIntrOp>;
66 ScalableMaskedUDivIIntrOp>;
69 ScalableMaskedDivFIntrOp>;
92 template <
typename Op,
typename IntrOp>
97 matchAndRewrite(
Op convertOp,
typename Op::Adaptor,
99 auto loc = convertOp.
getLoc();
101 auto source = convertOp.getSource();
102 VectorType sourceType = source.getType();
103 VectorType resultType = convertOp.getResult().getType();
106 loc, resultType, rewriter.
getZeroAttr(resultType));
112 tileShape.back() = sourceType.getShape().back();
117 auto extractOrInsertPosition =
ArrayRef(index).drop_back();
118 auto sourceVector = rewriter.
create<vector::ExtractOp>(
119 loc, source, extractOrInsertPosition);
120 VectorType convertedType =
122 .
setDim(0, resultType.getShape().back());
123 auto convertedVector =
125 result = rewriter.
create<vector::InsertOp>(loc, convertedVector, result,
126 extractOrInsertPosition);
134 using ConvertToSvboolOpLowering =
135 SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
137 using ConvertFromSvboolOpLowering =
138 SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
149 matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor,
152 auto loc = pselOp.getLoc();
153 auto svboolP1 = rewriter.
create<ConvertToSvboolIntrOp>(loc, svboolType,
155 auto indexI32 = rewriter.
create<arith::IndexCastOp>(
156 loc, rewriter.
getI32Type(), pselOp.getIndex());
157 auto pselIntr = rewriter.
create<PselIntrOp>(loc, svboolType, svboolP1,
158 pselOp.getP2(), indexI32);
160 pselOp, adaptor.getP1().getType(), pselIntr);
172 struct CreateMaskOpLowering
177 matchAndRewrite(vector::CreateMaskOp createMaskOp,
178 vector::CreateMaskOp::Adaptor adaptor,
180 auto maskType = createMaskOp.getVectorType();
181 if (maskType.getRank() != 1 || !maskType.isScalable())
185 auto maskBaseSize = maskType.getDimSize(0);
186 if (maskBaseSize < 2 || maskBaseSize > 16 ||
187 !llvm::isPowerOf2_32(uint32_t(maskBaseSize)))
189 "not SVE predicate-sized");
191 auto loc = createMaskOp.getLoc();
194 adaptor.getOperands()[0]);
224 ConvertToSvboolOpLowering,
225 ConvertFromSvboolOpLowering,
228 PselOpLowering>(converter);
231 patterns.
add<CreateMaskOpLowering>(converter, 4096);
242 ScalableMaskedAddIIntrOp,
243 ScalableMaskedAddFIntrOp,
244 ScalableMaskedSubIIntrOp,
245 ScalableMaskedSubFIntrOp,
246 ScalableMaskedMulIIntrOp,
247 ScalableMaskedMulFIntrOp,
248 ScalableMaskedSDivIIntrOp,
249 ScalableMaskedUDivIIntrOp,
250 ScalableMaskedDivFIntrOp,
251 ConvertToSvboolIntrOp,
252 ConvertFromSvboolIntrOp,
261 ScalableMaskedAddIOp,
262 ScalableMaskedAddFOp,
263 ScalableMaskedSubIOp,
264 ScalableMaskedSubFOp,
265 ScalableMaskedMulIOp,
266 ScalableMaskedMulFOp,
267 ScalableMaskedSDivIOp,
268 ScalableMaskedUDivIOp,
269 ScalableMaskedDivFOp,
TypedAttr getZeroAttr(Type type)
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
MLIRContext & getContext() const
Returns the MLIR context.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the various different ranges of value types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
This is a builder type that keeps local references to arguments.
Builder & setDim(unsigned pos, int64_t val)
Set a dim in shape @pos to val.
Include the generated interface declarations.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering ArmSVE ops to ops that map to LLVM intrinsics.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateArmSVELegalizeForLLVMExportPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower ArmSVE ops to ops that map to LLVM intrinsics.