MLIR 22.0.0git
VectorPattern.h
Go to the documentation of this file.
1//===- VectorPattern.h - Conversion pattern to the LLVM dialect -*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H
10#define MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H
11
14
15namespace mlir {
16
17namespace LLVM {
18namespace detail {
19// Helper struct to "unroll" operations on n-D vectors in terms of operations on
20// 1-D LLVM vectors.
22 // LLVM array struct which encodes n-D vectors.
24 // LLVM vector type which encodes the inner 1-D vector type.
26 // Multiplicity of llvmNDVectorTy to llvm1DVectorTy.
28};
29
30// For >1-D vector types, extracts the necessary information to iterate over all
31// 1-D subvectors in the underlying LLVM representation of the n-D vector.
32// Iterates on the LLVM array type until we hit a non-array type (which is
33// asserted to be an LLVM vector type).
34NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
35 const LLVMTypeConverter &converter);
36
37// Express `linearIndex` in terms of coordinates of `basis`.
38// Returns the empty vector when linearIndex is out of the range [0, P) where
39// P is the product of all the basis coordinates.
40//
41// Prerequisites:
42// Basis is an array of nonnegative integers (signed type inherited from
43// vector shape type).
45 unsigned linearIndex);
46
47// Iterate over the linear index, convert it to coordinate space, and insert a
48// splatted 1-D vector at each position.
49void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
51
53 Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
54 std::function<Value(Type, ValueRange)> createOperand,
55 ConversionPatternRewriter &rewriter);
56
// Presumably rewrites `op` into operation(s) named `targetOp`, built from
// `operands`, the attribute list `targetAttrs` and `propertiesAttr`, using
// `typeConverter` and `rewriter`. `propertiesAttr` may be a null Attribute
// (AttrConvertPassThrough::getPropAttr() returns `{}`). The returned
// LogicalResult is what VectorConvertToLLVMPattern::matchAndRewrite returns
// as the pattern's result. That caller static_asserts a single-result op and
// passes `adaptor.getOperands()` as `operands`.
// NOTE(review): VectorConvertToLLVMPattern is documented as supporting
// higher-dimensional vector types; presumably that unrolling happens here
// (cf. handleMultidimensionalVectors) -- confirm in the implementation.
57LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
58 ValueRange operands,
59 ArrayRef<NamedAttribute> targetAttrs,
60 Attribute propertiesAttr,
61 const LLVMTypeConverter &typeConverter,
62 ConversionPatternRewriter &rewriter);
63} // namespace detail
64} // namespace LLVM
65
66// Default attribute conversion class, which passes all source attributes
67// through to the target op, unmodified. The attribute used to set properties
68// of the target operation will be null (i.e. any properties that exist in the
69// target op will have default values).
70template <typename SourceOp, typename TargetOp>
72public:
73 AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {}
74
75 ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
76 Attribute getPropAttr() const { return {}; }
77
78private:
80};
81
82/// Basic lowering implementation to rewrite Ops with just one result to the
83/// LLVM Dialect. This supports higher-dimensional vector types.
84/// The AttrConvert template template parameter should:
85/// - be a template class with SourceOp and TargetOp type parameters
86/// - have a constructor that takes a SourceOp instance
87/// - have a getAttrs() method that returns ArrayRef<NamedAttribute> containing
88/// attributes that the target operation will have
89/// - have a getPropAttr() method that returns either a null attribute or a
90/// DictionaryAttr with properties that exist for the target operation
91template <typename SourceOp, typename TargetOp,
92 template <typename, typename> typename AttrConvert =
93 AttrConvertPassThrough,
94 bool FailOnUnsupportedFP = false>
96public:
99
100 /// Return the given type if it's a floating point type. If the given type is
101 /// a vector type, return its element type if it's a floating point type.
102 static FloatType getFloatingPointType(Type type) {
103 if (auto floatType = dyn_cast<FloatType>(type))
104 return floatType;
105 if (auto vecType = dyn_cast<VectorType>(type))
106 return dyn_cast<FloatType>(vecType.getElementType());
107 return nullptr;
108 }
109
110 LogicalResult
111 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
112 ConversionPatternRewriter &rewriter) const override {
113 static_assert(
114 std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
115 "expected single result op");
116
117 // The pattern should not apply if a floating-point operand is converted to
118 // a non-floating-point type. This indicates that the floating point type
119 // is not supported by the LLVM lowering. (Such types are converted to
120 // integers.)
121 auto checkType = [&](Value v) -> LogicalResult {
122 FloatType floatType = getFloatingPointType(v.getType());
123 if (!floatType)
124 return success();
125 Type convertedType = this->getTypeConverter()->convertType(floatType);
126 if (!isa_and_nonnull<FloatType>(convertedType))
127 return rewriter.notifyMatchFailure(op,
128 "unsupported floating point type");
129 return success();
130 };
131 if (FailOnUnsupportedFP) {
132 for (Value operand : op->getOperands())
133 if (failed(checkType(operand)))
134 return failure();
135 if (failed(checkType(op->getResult(0))))
136 return failure();
137 }
138
139 // Determine attributes for the target op
140 AttrConvert<SourceOp, TargetOp> attrConvert(op);
141
143 op, TargetOp::getOperationName(), adaptor.getOperands(),
144 attrConvert.getAttrs(), attrConvert.getPropAttr(),
145 *this->getTypeConverter(), rewriter);
146 }
147};
148} // namespace mlir
149
150#endif // MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H
return success()
AttrConvertPassThrough(SourceOp srcOp)
ArrayRef< NamedAttribute > getAttrs() const
Attribute getPropAttr() const
Attributes are known-constant values of operations.
Definition Attributes.h:25
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:213
const LLVMTypeConverter * getTypeConverter() const
Definition Pattern.cpp:27
Conversion from types to the LLVM IR dialect.
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
VectorConvertToLLVMPattern< SourceOp, TargetOp > Super
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
static FloatType getFloatingPointType(Type type)
Return the given type if it's a floating point type.
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, Attribute propertiesAttr, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
SmallVector< int64_t, 4 > arraySizes