23 template <
typename OpTy>
25 return cast<VectorType>(op.getSrc().getType()).getElementType();
29 return cast<VectorType>(op.getA().getType()).getElementType();
38 template <
typename OpTy,
typename Intr32OpTy,
typename Intr64OpTy>
49 matchAndRewrite(OpTy op,
typename OpTy::Adaptor adaptor,
51 Type elementType = getSrcVectorElementType<OpTy>(op);
55 op, Intr32OpTy::getOperationName(), adaptor.getOperands(),
56 op->getAttrs(), getTypeConverter(), rewriter);
59 op, Intr64OpTy::getOperationName(), adaptor.getOperands(),
60 op->getAttrs(), getTypeConverter(), rewriter);
62 op,
"expected 'src' to be either f32 or f64");
66 struct MaskCompressOpConversion
71 matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor,
73 auto opType = adaptor.getA().getType();
77 src = adaptor.getSrc();
78 }
else if (op.getConstantSrc()) {
79 src = rewriter.
create<arith::ConstantOp>(op.getLoc(), opType,
80 op.getConstantSrcAttr());
83 src = rewriter.
create<arith::ConstantOp>(op->getLoc(), opType, zeroAttr);
97 matchAndRewrite(RsqrtOp op, OpAdaptor adaptor,
99 auto opType = adaptor.getA().getType();
109 matchAndRewrite(DotOp op, OpAdaptor adaptor,
111 auto opType = adaptor.getA().getType();
116 rewriter.
create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr);
118 adaptor.getB(), scale);
125 template <
typename OpTy,
typename Intr32OpTy,
typename Intr64OpTy>
128 using Intr32Op = Intr32OpTy;
129 using Intr64Op = Intr64OpTy;
134 template <
typename... Args>
135 struct RegistryImpl {
141 .
add<LowerToIntrinsic<
typename Args::MainOp,
typename Args::Intr32Op,
142 typename Args::Intr64Op>...>(converter);
148 target.
addLegalOp<
typename Args::Intr32Op...>();
149 target.
addLegalOp<
typename Args::Intr64Op...>();
153 using Registry = RegistryImpl<
154 RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
155 RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
156 RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;
163 Registry::registerPatterns(converter, patterns);
164 patterns.
add<MaskCompressOpConversion, RsqrtOpConversion, DotOpConversion>(
170 Registry::configureTarget(target);
static MLIRContext * getContext(OpFoldResult val)
TypedAttr getZeroAttr(Type type)
IntegerAttr getI8IntegerAttr(int8_t value)
This class implements a pattern rewriter for use with ConversionPatterns.
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
void addLegalOp(OperationName op)
Register the given operations as legal.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Include the generated interface declarations.
void populateX86VectorLegalizeForLLVMExportPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower X86Vector ops to ops that map to LLVM intrinsics.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void configureX86VectorLegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering X86Vector ops to ops that map to LLVM intrinsics.