MLIR  18.0.0git
Utils.cpp
Go to the documentation of this file.
1 //===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for the Linalg dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "llvm/ADT/SmallBitVector.h"
18 
19 using namespace mlir;
20 
21 /// Matches a ConstantIndexOp.
22 /// TODO: This should probably just be a general matcher that uses matchConstant
23 /// and checks the operation for an index type.
26 }
27 
28 llvm::SmallBitVector mlir::getPositionsOfShapeOne(unsigned rank,
29  ArrayRef<int64_t> shape) {
30  llvm::SmallBitVector dimsToProject(shape.size());
31  for (unsigned pos = 0, e = shape.size(); pos < e && rank > 0; ++pos) {
32  if (shape[pos] == 1) {
33  dimsToProject.set(pos);
34  --rank;
35  }
36  }
37  return dimsToProject;
38 }
39 
41  OpFoldResult ofr) {
42  if (auto value = llvm::dyn_cast_if_present<Value>(ofr))
43  return value;
44  auto attr = dyn_cast<IntegerAttr>(llvm::dyn_cast_if_present<Attribute>(ofr));
45  assert(attr && "expect the op fold result casts to an integer attribute");
46  return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue());
47 }
48 
50  Type targetType, Value value) {
51  if (targetType == value.getType())
52  return value;
53 
54  bool targetIsIndex = targetType.isIndex();
55  bool valueIsIndex = value.getType().isIndex();
56  if (targetIsIndex ^ valueIsIndex)
57  return b.create<arith::IndexCastOp>(loc, targetType, value);
58 
59  auto targetIntegerType = dyn_cast<IntegerType>(targetType);
60  auto valueIntegerType = dyn_cast<IntegerType>(value.getType());
61  assert(targetIntegerType && valueIntegerType &&
62  "unexpected cast between types other than integers and index");
63  assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness());
64 
65  if (targetIntegerType.getWidth() > valueIntegerType.getWidth())
66  return b.create<arith::ExtSIOp>(loc, targetIntegerType, value);
67  return b.create<arith::TruncIOp>(loc, targetIntegerType, value);
68 }
69 
71  IntegerType toType, bool isUnsigned) {
72  // If operand is floating point, cast directly to the int type.
73  if (isa<FloatType>(operand.getType())) {
74  if (isUnsigned)
75  return b.create<arith::FPToUIOp>(toType, operand);
76  return b.create<arith::FPToSIOp>(toType, operand);
77  }
78  // Cast index operands directly to the int type.
79  if (operand.getType().isIndex())
80  return b.create<arith::IndexCastOp>(toType, operand);
81  if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) {
82  // Either extend or truncate.
83  if (toType.getWidth() > fromIntType.getWidth()) {
84  if (isUnsigned)
85  return b.create<arith::ExtUIOp>(toType, operand);
86  return b.create<arith::ExtSIOp>(toType, operand);
87  }
88  if (toType.getWidth() < fromIntType.getWidth())
89  return b.create<arith::TruncIOp>(toType, operand);
90  return operand;
91  }
92 
93  return {};
94 }
95 
97  FloatType toType, bool isUnsigned) {
98  // If operand is integer, cast directly to the float type.
99  // Note that it is unclear how to cast from BF16<->FP16.
100  if (isa<IntegerType>(operand.getType())) {
101  if (isUnsigned)
102  return b.create<arith::UIToFPOp>(toType, operand);
103  return b.create<arith::SIToFPOp>(toType, operand);
104  }
105  if (auto fromFpTy = dyn_cast<FloatType>(operand.getType())) {
106  if (toType.getWidth() > fromFpTy.getWidth())
107  return b.create<arith::ExtFOp>(toType, operand);
108  if (toType.getWidth() < fromFpTy.getWidth())
109  return b.create<arith::TruncFOp>(toType, operand);
110  return operand;
111  }
112 
113  return {};
114 }
115 
117  ComplexType targetType,
118  bool isUnsigned) {
119  if (auto fromComplexType = dyn_cast<ComplexType>(operand.getType())) {
120  if (isa<FloatType>(targetType.getElementType()) &&
121  isa<FloatType>(fromComplexType.getElementType())) {
122  Value real = b.create<complex::ReOp>(operand);
123  Value imag = b.create<complex::ImOp>(operand);
124  Type targetETy = targetType.getElementType();
125  if (targetType.getElementType().getIntOrFloatBitWidth() <
126  fromComplexType.getElementType().getIntOrFloatBitWidth()) {
127  real = b.create<arith::TruncFOp>(targetETy, real);
128  imag = b.create<arith::TruncFOp>(targetETy, imag);
129  } else {
130  real = b.create<arith::ExtFOp>(targetETy, real);
131  imag = b.create<arith::ExtFOp>(targetETy, imag);
132  }
133  return b.create<complex::CreateOp>(targetType, real, imag);
134  }
135  }
136 
137  if (dyn_cast<FloatType>(operand.getType())) {
138  FloatType toFpTy = cast<FloatType>(targetType.getElementType());
139  auto toBitwidth = toFpTy.getIntOrFloatBitWidth();
140  Value from = operand;
141  if (from.getType().getIntOrFloatBitWidth() < toBitwidth) {
142  from = b.create<arith::ExtFOp>(toFpTy, from);
143  }
144  if (from.getType().getIntOrFloatBitWidth() > toBitwidth) {
145  from = b.create<arith::TruncFOp>(toFpTy, from);
146  }
148  mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
149  return b.create<complex::CreateOp>(targetType, from, zero);
150  }
151 
152  if (dyn_cast<IntegerType>(operand.getType())) {
153  FloatType toFpTy = cast<FloatType>(targetType.getElementType());
154  Value from = operand;
155  if (isUnsigned) {
156  from = b.create<arith::UIToFPOp>(toFpTy, from);
157  } else {
158  from = b.create<arith::SIToFPOp>(toFpTy, from);
159  }
161  mlir::APFloat(toFpTy.getFloatSemantics(), 0), toFpTy);
162  return b.create<complex::CreateOp>(targetType, from, zero);
163  }
164 
165  return {};
166 }
167 
169  Type toType, bool isUnsignedCast) {
170  if (operand.getType() == toType)
171  return operand;
172  ImplicitLocOpBuilder ib(loc, b);
173  Value result;
174  if (auto intTy = dyn_cast<IntegerType>(toType)) {
175  result = convertScalarToIntDtype(ib, operand, intTy, isUnsignedCast);
176  } else if (auto floatTy = dyn_cast<FloatType>(toType)) {
177  result = convertScalarToFpDtype(ib, operand, floatTy, isUnsignedCast);
178  } else if (auto complexTy = dyn_cast<ComplexType>(toType)) {
179  result =
180  convertScalarToComplexDtype(ib, operand, complexTy, isUnsignedCast);
181  }
182 
183  if (result)
184  return result;
185 
186  emitWarning(loc) << "could not cast operand of type " << operand.getType()
187  << " to " << toType;
188  return operand;
189 }
190 
193  ArrayRef<OpFoldResult> valueOrAttrVec) {
194  return llvm::to_vector<4>(
195  llvm::map_range(valueOrAttrVec, [&](OpFoldResult value) -> Value {
196  return getValueOrCreateConstantIndexOp(b, loc, value);
197  }));
198 }
199 
201  return b.create<arith::AndIOp>(loc, lhs, rhs);
202 }
204  if (isa<FloatType>(lhs.getType()))
205  return b.create<arith::AddFOp>(loc, lhs, rhs);
206  return b.create<arith::AddIOp>(loc, lhs, rhs);
207 }
209  if (isa<FloatType>(lhs.getType()))
210  return b.create<arith::SubFOp>(loc, lhs, rhs);
211  return b.create<arith::SubIOp>(loc, lhs, rhs);
212 }
214  if (isa<FloatType>(lhs.getType()))
215  return b.create<arith::MulFOp>(loc, lhs, rhs);
216  return b.create<arith::MulIOp>(loc, lhs, rhs);
217 }
219  if (isa<FloatType>(lhs.getType()))
220  return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs);
221  return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs);
222 }
224  if (isa<FloatType>(lhs.getType()))
225  return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs);
226  return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
227 }
229  return b.create<arith::SelectOp>(loc, cmp, lhs, rhs);
230 }
static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, ComplexType targetType, bool isUnsigned)
Definition: Utils.cpp:116
static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, IntegerType toType, bool isUnsigned)
Definition: Utils.cpp:70
static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand, FloatType toType, bool isUnsigned)
Definition: Utils.cpp:96
const llvm::fltSemantics & getFloatSemantics()
Return the floating semantics of this float type.
unsigned getWidth()
Return the bitwidth of this float type.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without specifying the location.
OpTy create(Args &&...args)
Create an operation of specific op type at the current insertion point and location.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:63
This class helps build Operations.
Definition: Builders.h:206
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:446
This class represents a single result from folding an operation.
Definition: OpDefinition.h:266
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:123
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:93
Type getType() const
Return the type of this value.
Definition: Value.h:122
Specialization of arith.constant op that returns a floating point value.
Definition: Arith.h:74
Specialization of arith.constant op that returns an integer of index type.
Definition: Arith.h:90
This header declares functions that assist transformations in the MemRef dialect.
InFlightDiagnostic emitWarning(Location loc)
Utility method to emit a warning message using this location.
Value convertScalarToDtype(OpBuilder &b, Location loc, Value operand, Type toType, bool isUnsignedCast)
Converts a scalar value operand to type toType.
Definition: Utils.cpp:168
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
Definition: Utils.cpp:49
Value getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Definition: Utils.cpp:40
llvm::SmallBitVector getPositionsOfShapeOne(unsigned rank, ArrayRef< int64_t > shape)
Definition: Utils.cpp:28
detail::op_matcher< arith::ConstantIndexOp > matchConstantIndex()
Matches a ConstantIndexOp.
Definition: Utils.cpp:24
Value mul(Value lhs, Value rhs)
Definition: Utils.cpp:213
Value _and(Value lhs, Value rhs)
Definition: Utils.cpp:200
Value slt(Value lhs, Value rhs)
Definition: Utils.cpp:223
Value select(Value cmp, Value lhs, Value rhs)
Definition: Utils.cpp:228
Value add(Value lhs, Value rhs)
Definition: Utils.cpp:203
Value sgt(Value lhs, Value rhs)
Definition: Utils.cpp:218
Value sub(Value lhs, Value rhs)
Definition: Utils.cpp:208
The matcher that matches a certain kind of op.
Definition: Matchers.h:224