MLIR  14.0.0git
VectorPattern.cpp
Go to the documentation of this file.
1 //===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 
12 using namespace mlir;
13 
14 // For >1-D vector types, extracts the necessary information to iterate over all
15 // 1-D subvectors in the underlying LLVM representation of the n-D vector.
16 // Iterates on the LLVM array type until we hit a non-array type; if that type
17 // is not an LLVM-compatible vector type, the 1-D vector type is left unset.
20  LLVMTypeConverter &converter) {
21  assert(vectorType.getRank() > 1 && "expected >1D vector type");
22  NDVectorTypeInfo info;
23  info.llvmNDVectorTy = converter.convertType(vectorType);
24  if (!info.llvmNDVectorTy || !LLVM::isCompatibleType(info.llvmNDVectorTy)) {
25  info.llvmNDVectorTy = nullptr;
26  return info;
27  }
28  info.arraySizes.reserve(vectorType.getRank() - 1);
29  auto llvmTy = info.llvmNDVectorTy;
30  while (llvmTy.isa<LLVM::LLVMArrayType>()) {
31  info.arraySizes.push_back(
32  llvmTy.cast<LLVM::LLVMArrayType>().getNumElements());
33  llvmTy = llvmTy.cast<LLVM::LLVMArrayType>().getElementType();
34  }
35  if (!LLVM::isCompatibleVectorType(llvmTy))
36  return info;
37  info.llvm1DVectorTy = llvmTy;
38  return info;
39 }
40 
41 // Express `linearIndex` in terms of coordinates of `basis`.
42 // Returns the empty vector when linearIndex is out of the range [0, P) where
43 // P is the product of all the basis coordinates.
44 //
45 // Prerequisites:
46 // Basis is an array of positive integers (signed type inherited from
47 // vector shape type); each element is used as a divisor.
49  unsigned linearIndex) {
51  res.reserve(basis.size());
52  for (unsigned basisElement : llvm::reverse(basis)) {
53  res.push_back(linearIndex % basisElement);
54  linearIndex = linearIndex / basisElement;
55  }
56  if (linearIndex > 0)
57  return {};
58  std::reverse(res.begin(), res.end());
59  return res;
60 }
61 
62 // Iterate over the linear indices, convert each to coordinate space and invoke
63 // `fun` on the resulting position (e.g. to insert a 1-D vector there).
65  OpBuilder &builder,
66  function_ref<void(ArrayAttr)> fun) {
67  unsigned ub = 1;
68  for (auto s : info.arraySizes)
69  ub *= s;
70  for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
71  auto coords = getCoordinates(info.arraySizes, linearIndex);
72  // Linear index is out of bounds, we are done.
73  if (coords.empty())
74  break;
75  assert(coords.size() == info.arraySizes.size());
76  auto position = builder.getI64ArrayAttr(coords);
77  fun(position);
78  }
79 }
80 
82  Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
83  std::function<Value(Type, ValueRange)> createOperand,
84  ConversionPatternRewriter &rewriter) {
85  auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
86 
87  SmallVector<Type> operand1DVectorTypes;
88  for (Value operand : op->getOperands()) {
89  auto operandNDVectorType = operand.getType().cast<VectorType>();
90  auto operandTypeInfo =
91  extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
92  operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy);
93  }
94  auto resultTypeInfo =
95  extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
96  auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
97  auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
98  auto loc = op->getLoc();
99  Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy);
100  nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayAttr position) {
101  // For this unrolled `position` corresponding to the `linearIndex`^th
102  // element, extract operand vectors
103  SmallVector<Value, 4> extractedOperands;
104  for (const auto &operand : llvm::enumerate(operands)) {
105  extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
106  loc, operand1DVectorTypes[operand.index()], operand.value(),
107  position));
108  }
109  Value newVal = createOperand(result1DVectorTy, extractedOperands);
110  desc = rewriter.create<LLVM::InsertValueOp>(loc, resultNDVectoryTy, desc,
111  newVal, position);
112  });
113  rewriter.replaceOp(op, desc);
114  return success();
115 }
116 
118  Operation *op, StringRef targetOp, ValueRange operands,
119  LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
120  assert(!operands.empty());
121 
122  // Cannot convert ops if their operands are not of LLVM type.
123  if (!llvm::all_of(operands.getTypes(),
124  [](Type t) { return isCompatibleType(t); }))
125  return failure();
126 
127  auto llvmNDVectorTy = operands[0].getType();
128  if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>())
129  return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
130 
131  auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy,
132  ValueRange operands) {
133  OperationState state(op->getLoc(), targetOp);
134  state.addTypes(llvm1DVectorTy);
135  state.addOperands(operands);
136  state.addAttributes(op->getAttrs());
137  return rewriter.createOperation(state)->getResult(0);
138  };
139 
140  return handleMultidimensionalVectors(op, operands, typeConverter, callback,
141  rewriter);
142 }
Include the generated interface declarations.
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, LLVMTypeConverter &converter)
Operation is a basic unit of execution within MLIR.
Definition: Operation.h:28
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:247
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results)
Convert the given type.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:308
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:639
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
SmallVector< int64_t, 4 > arraySizes
Definition: VectorPattern.h:27
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:220
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing the results of an operation.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:752
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
Operation * createOperation(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:380
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:206
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:276
void addTypes(ArrayRef< Type > newTypes)
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:106
This represents an operation in an abstracted form, suitable for use with the builder APIs...
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
auto getType() const
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
LLVM dialect array type.
Definition: LLVMTypes.h:74
unsigned getNumElements() const
Returns the number of elements in the array type.
Definition: LLVMTypes.cpp:54
Type getType() const
Return the type of this value.
Definition: Value.h:117
type_range getTypes() const
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect. ...
Definition: LLVMTypes.cpp:762
Conversion from types in the Standard dialect to the LLVM IR dialect.
Definition: TypeConverter.h:30
static VectorType vectorType(CodeGen &codegen, Type etp)
Constructs vector type.
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayAttr)> fun)
This class implements a pattern rewriter for use with ConversionPatterns.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands...
Definition: Pattern.cpp:309
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
This class helps build Operations.
Definition: Builders.h:177
This class provides an abstraction over the different types of ranges over Values.
U cast() const
Definition: Types.h:250