30 ScalableMaskedAddIIntrOp>;
33 ScalableMaskedAddFIntrOp>;
36 ScalableMaskedSubIIntrOp>;
39 ScalableMaskedSubFIntrOp>;
42 ScalableMaskedMulIIntrOp>;
45 ScalableMaskedMulFIntrOp>;
48 ScalableMaskedSDivIIntrOp>;
51 ScalableMaskedUDivIIntrOp>;
54 ScalableMaskedDivFIntrOp>;
77template <
typename Op,
typename IntrOp>
82 matchAndRewrite(
Op convertOp,
typename Op::Adaptor,
83 ConversionPatternRewriter &rewriter)
const override {
84 auto loc = convertOp.
getLoc();
86 auto source = convertOp.getSource();
87 VectorType sourceType = source.getType();
88 VectorType resultType = convertOp.getResult().getType();
90 Value result = arith::ConstantOp::create(rewriter, loc, resultType,
91 rewriter.getZeroAttr(resultType));
97 tileShape.back() = sourceType.getShape().back();
103 auto sourceVector = vector::ExtractOp::create(rewriter, loc, source,
104 extractOrInsertPosition);
105 VectorType convertedType =
107 .
setDim(0, resultType.getShape().back());
108 auto convertedVector =
109 IntrOp::create(rewriter, loc,
TypeRange{convertedType}, sourceVector);
110 result = vector::InsertOp::create(rewriter, loc, convertedVector,
result,
111 extractOrInsertPosition);
114 rewriter.replaceOp(convertOp,
result);
119using ConvertToSvboolOpLowering =
120 SvboolConversionOpLowering<ConvertToSvboolOp, ConvertToSvboolIntrOp>;
122using ConvertFromSvboolOpLowering =
123 SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
134 matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor,
135 ConversionPatternRewriter &rewriter)
const override {
136 auto svboolType = VectorType::get(16, rewriter.getI1Type(),
true);
137 auto loc = pselOp.getLoc();
138 auto svboolP1 = ConvertToSvboolIntrOp::create(rewriter, loc, svboolType,
140 auto indexI32 = arith::IndexCastOp::create(
141 rewriter, loc, rewriter.getI32Type(), pselOp.getIndex());
142 auto pselIntr = PselIntrOp::create(rewriter, loc, svboolType, svboolP1,
143 pselOp.getP2(), indexI32);
144 rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>(
145 pselOp, adaptor.getP1().
getType(), pselIntr);
157struct CreateMaskOpLowering
162 matchAndRewrite(vector::CreateMaskOp createMaskOp,
163 vector::CreateMaskOp::Adaptor adaptor,
164 ConversionPatternRewriter &rewriter)
const override {
165 auto maskType = createMaskOp.getVectorType();
166 if (maskType.getRank() != 1 || !maskType.isScalable())
167 return rewriter.notifyMatchFailure(createMaskOp,
"not 1-D and scalable");
170 auto maskBaseSize = maskType.getDimSize(0);
171 if (maskBaseSize < 2 || maskBaseSize > 16 ||
172 !llvm::isPowerOf2_32(uint32_t(maskBaseSize)))
173 return rewriter.notifyMatchFailure(createMaskOp,
174 "not SVE predicate-sized");
176 auto loc = createMaskOp.getLoc();
177 auto zero = LLVM::ZeroOp::create(rewriter, loc, rewriter.getI64Type());
178 rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero,
179 adaptor.getOperands()[0]);
192 patterns.add<ConvertFromSvboolOpLowering,
193 ConvertToSvboolOpLowering,
214 patterns.add<CreateMaskOpLowering>(converter, 4096);
221 target.addLegalOp<BfmmlaOp,
222 ConvertFromSvboolIntrOp,
223 ConvertToSvboolIntrOp,
226 ScalableMaskedAddFIntrOp,
227 ScalableMaskedAddIIntrOp,
228 ScalableMaskedDivFIntrOp,
229 ScalableMaskedMulFIntrOp,
230 ScalableMaskedMulIIntrOp,
231 ScalableMaskedSDivIIntrOp,
232 ScalableMaskedSubFIntrOp,
233 ScalableMaskedSubIIntrOp,
234 ScalableMaskedUDivIIntrOp,
243 target.addIllegalOp<ConvertFromSvboolOp,
247 ScalableMaskedAddFOp,
248 ScalableMaskedAddIOp,
249 ScalableMaskedDivFOp,
250 ScalableMaskedMulFOp,
251 ScalableMaskedMulIOp,
252 ScalableMaskedSDivIOp,
253 ScalableMaskedSubFOp,
254 ScalableMaskedSubIOp,
255 ScalableMaskedUDivIOp,
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Generic implementation of one-to-one conversion from "SourceOp" to "TargetOp" where the latter belong...
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
A range-style iterator that allows for iterating over the offsets of all potential tiles of size tile...
This class provides an abstraction over the various different ranges of value types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
This is a builder type that keeps local references to arguments.
Builder & setDim(unsigned pos, int64_t val)
Set a dim in shape @pos to val.
Include the generated interface declarations.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering ArmSVE ops to ops that map to LLVM intrinsics.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
const FrozenRewritePatternSet & patterns
void populateArmSVELegalizeForLLVMExportPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower ArmSVE ops to ops that map to LLVM intrinsics.