MLIR  21.0.0git
LegalizeForLLVMExport.cpp
Go to the documentation of this file.
1 //===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
15 #include "mlir/IR/BuiltinOps.h"
16 #include "mlir/IR/PatternMatch.h"
17 
18 using namespace mlir;
19 using namespace mlir::x86vector;
20 
21 namespace {
22 
23 /// Replaces an operation with a call to an LLVM intrinsic with the specified
24 /// name and operands.
25 ///
26 /// The rewrite performs a simple one-to-one matching between the op and LLVM
27 /// intrinsic. For example:
28 ///
29 /// ```mlir
30 /// %res = x86vector.op %val : vector<16xf32>
31 /// ```
32 ///
33 /// can be converted to
34 ///
35 /// ```mlir
36 /// %res = llvm.call_intrinsic "intrinsic"(%val)
37 /// ```
38 ///
39 /// The provided operands must be LLVM-compatible.
40 ///
41 /// Upholds a convention that multi-result operations get converted into an
42 /// operation returning the LLVM IR structure type, in which case individual
43 /// values are first extracted before replacing the original results.
44 LogicalResult intrinsicRewrite(Operation *op, StringAttr intrinsic,
45  ValueRange operands,
46  const LLVMTypeConverter &typeConverter,
47  PatternRewriter &rewriter) {
48  auto loc = op->getLoc();
49 
50  if (!llvm::all_of(operands, [](Value value) {
51  return LLVM::isCompatibleType(value.getType());
52  }))
53  return rewriter.notifyMatchFailure(op, "Expects LLVM-compatible types.");
54 
55  unsigned numResults = op->getNumResults();
56  Type resType;
57  if (numResults != 0)
58  resType = typeConverter.packOperationResults(op->getResultTypes());
59 
60  auto callIntrOp =
61  rewriter.create<LLVM::CallIntrinsicOp>(loc, resType, intrinsic, operands);
62  // Propagate attributes.
63  callIntrOp->setAttrs(op->getAttrDictionary());
64 
65  if (numResults <= 1) {
66  // Directly replace the original op.
67  rewriter.replaceOp(op, callIntrOp);
68  return success();
69  }
70 
71  // Extract individual results from packed structure and use them as
72  // replacements.
73  SmallVector<Value, 4> results;
74  results.reserve(numResults);
75  Value intrRes = callIntrOp.getResults();
76  for (unsigned i = 0; i < numResults; ++i) {
77  results.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, intrRes, i));
78  }
79  rewriter.replaceOp(op, results);
80 
81  return success();
82 }
83 
84 /// Generic one-to-one conversion of simply mappable operations into calls
85 /// to their respective LLVM intrinsics.
86 struct OneToOneIntrinsicOpConversion
87  : public OpInterfaceRewritePattern<x86vector::OneToOneIntrinsicOp> {
89  x86vector::OneToOneIntrinsicOp>::OpInterfaceRewritePattern;
90 
91  OneToOneIntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
92  PatternBenefit benefit = 1)
93  : OpInterfaceRewritePattern(&typeConverter.getContext(), benefit),
94  typeConverter(typeConverter) {}
95 
96  LogicalResult matchAndRewrite(x86vector::OneToOneIntrinsicOp op,
97  PatternRewriter &rewriter) const override {
98  return intrinsicRewrite(op, rewriter.getStringAttr(op.getIntrinsicName()),
99  op.getIntrinsicOperands(rewriter, typeConverter),
100  typeConverter, rewriter);
101  }
102 
103 private:
104  const LLVMTypeConverter &typeConverter;
105 };
106 
107 } // namespace
108 
109 /// Populate the given list with patterns that convert from X86Vector to LLVM.
111  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
112  patterns.add<OneToOneIntrinsicOpConversion>(converter);
113 }
114 
116  LLVMConversionTarget &target) {
117  target.addIllegalOp<
118  MaskCompressOp, MaskRndScaleOp, MaskScaleFOp, Vp2IntersectOp, DotBF16Op,
119  CvtPackedF32ToBF16Op, CvtPackedEvenIndexedBF16ToF32Op,
120  CvtPackedOddIndexedBF16ToF32Op, BcstBF16ToPackedF32Op, RsqrtOp, DotOp>();
121 }
static MLIRContext * getContext(OpFoldResult val)
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:258
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e. this operation is known to not be supported by this target.
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
Type packOperationResults(TypeRange types) const
Convert a non-empty list of types of values produced by an operation into an LLVM-compatible type.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
DictionaryAttr getAttrDictionary()
Return all of the attributes on this operation as a DictionaryAttr.
Definition: Operation.cpp:296
void setAttrs(DictionaryAttr newAttrs)
Set the attributes from a dictionary on this operation.
Definition: Operation.cpp:305
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
result_type_range getResultTypes()
Definition: Operation.h:428
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very little benefit) to 65K.
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Definition: PatternMatch.h:753
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:686
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:793
Include the generated interface declarations.
void populateX86VectorLegalizeForLLVMExportPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower X86Vector ops to ops that map to LLVM intrinsics.
const FrozenRewritePatternSet & patterns
void configureX86VectorLegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering X86Vector ops to ops that map to LLVM intrinsics.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of an operation interface instead of a raw Operation.
Definition: PatternMatch.h:334