MLIR  20.0.0git
PolynomialOps.cpp
Go to the documentation of this file.
1 //===- PolynomialOps.cpp - Polynomial dialect ops ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
14 #include "mlir/IR/Builders.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "llvm/ADT/APInt.h"
19 
20 using namespace mlir;
21 using namespace mlir::polynomial;
22 
23 void FromTensorOp::build(OpBuilder &builder, OperationState &result,
24  Value input, RingAttr ring) {
25  TensorType tensorType = dyn_cast<TensorType>(input.getType());
26  auto bitWidth = tensorType.getElementTypeBitWidth();
27  APInt cmod(1 + bitWidth, 1);
28  cmod = cmod << bitWidth;
29  Type resultType = PolynomialType::get(builder.getContext(), ring);
30  build(builder, result, resultType, input);
31 }
32 
33 LogicalResult FromTensorOp::verify() {
34  ArrayRef<int64_t> tensorShape = getInput().getType().getShape();
35  RingAttr ring = getOutput().getType().getRing();
36  IntPolynomialAttr polyMod = ring.getPolynomialModulus();
37  if (polyMod) {
38  unsigned polyDegree = polyMod.getPolynomial().getDegree();
39  bool compatible = tensorShape.size() == 1 && tensorShape[0] <= polyDegree;
40  if (!compatible) {
41  InFlightDiagnostic diag = emitOpError()
42  << "input type " << getInput().getType()
43  << " does not match output type "
44  << getOutput().getType();
45  diag.attachNote()
46  << "the input type must be a tensor of shape [d] where d "
47  "is at most the degree of the polynomialModulus of "
48  "the output type's ring attribute";
49  return diag;
50  }
51  }
52 
53  unsigned inputBitWidth = getInput().getType().getElementTypeBitWidth();
54  if (inputBitWidth > ring.getCoefficientType().getIntOrFloatBitWidth()) {
55  InFlightDiagnostic diag = emitOpError()
56  << "input tensor element type "
57  << getInput().getType().getElementType()
58  << " is too large to fit in the coefficients of "
59  << getOutput().getType();
60  diag.attachNote() << "the input tensor's elements must be rescaled"
61  " to fit before using from_tensor";
62  return diag;
63  }
64 
65  return success();
66 }
67 
68 LogicalResult ToTensorOp::verify() {
69  ArrayRef<int64_t> tensorShape = getOutput().getType().getShape();
70  IntPolynomialAttr polyMod =
71  getInput().getType().getRing().getPolynomialModulus();
72  if (polyMod) {
73  unsigned polyDegree = polyMod.getPolynomial().getDegree();
74  bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
75 
76  if (compatible)
77  return success();
78 
79  InFlightDiagnostic diag = emitOpError()
80  << "input type " << getInput().getType()
81  << " does not match output type "
82  << getOutput().getType();
83  diag.attachNote()
84  << "the output type must be a tensor of shape [d] where d "
85  "is at most the degree of the polynomialModulus of "
86  "the input type's ring attribute";
87  return diag;
88  }
89 
90  return success();
91 }
92 
93 LogicalResult MulScalarOp::verify() {
94  Type argType = getPolynomial().getType();
95  PolynomialType polyType;
96 
97  if (auto shapedPolyType = dyn_cast<ShapedType>(argType)) {
98  polyType = cast<PolynomialType>(shapedPolyType.getElementType());
99  } else {
100  polyType = cast<PolynomialType>(argType);
101  }
102 
103  Type coefficientType = polyType.getRing().getCoefficientType();
104 
105  if (coefficientType != getScalar().getType())
106  return emitOpError() << "polynomial coefficient type " << coefficientType
107  << " does not match scalar type "
108  << getScalar().getType();
109 
110  return success();
111 }
112 
113 /// Test if a value is a primitive nth root of unity modulo cmod.
114 bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
115  const APInt &cmod) {
116  // The first or subsequent multiplications, may overflow the input bit width,
117  // so scale them up to ensure they do not overflow.
118  unsigned requiredBitWidth =
119  std::max(root.getActiveBits() * 2, cmod.getActiveBits() * 2);
120  APInt r = APInt(root).zextOrTrunc(requiredBitWidth);
121  APInt cmodExt = APInt(cmod).zextOrTrunc(requiredBitWidth);
122  assert(r.ule(cmodExt) && "root must be less than cmod");
123  uint64_t upperBound = n.getZExtValue();
124 
125  APInt a = r;
126  for (size_t k = 1; k < upperBound; k++) {
127  if (a.isOne())
128  return false;
129  a = (a * r).urem(cmodExt);
130  }
131  return a.isOne();
132 }
133 
134 /// Verify that the types involved in an NTT or INTT operation are
135 /// compatible.
136 static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
137  RankedTensorType tensorType,
138  std::optional<PrimitiveRootAttr> root) {
139  Attribute encoding = tensorType.getEncoding();
140  if (!encoding) {
141  return op->emitOpError()
142  << "expects a ring encoding to be provided to the tensor";
143  }
144  auto encodedRing = dyn_cast<RingAttr>(encoding);
145  if (!encodedRing) {
146  return op->emitOpError()
147  << "the provided tensor encoding is not a ring attribute";
148  }
149 
150  if (encodedRing != ring) {
151  return op->emitOpError()
152  << "encoded ring type " << encodedRing
153  << " is not equivalent to the polynomial ring " << ring;
154  }
155 
156  unsigned polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
157  ArrayRef<int64_t> tensorShape = tensorType.getShape();
158  bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
159  if (!compatible) {
161  << "tensor type " << tensorType
162  << " does not match output type " << ring;
163  diag.attachNote() << "the tensor must have shape [d] where d "
164  "is exactly the degree of the polynomialModulus of "
165  "the polynomial type's ring attribute";
166  return diag;
167  }
168 
169  if (root.has_value()) {
170  APInt rootValue = root.value().getValue().getValue();
171  APInt rootDegree = root.value().getDegree().getValue();
172  APInt cmod = ring.getCoefficientModulus().getValue();
173  if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
174  return op->emitOpError()
175  << "provided root " << rootValue.getZExtValue()
176  << " is not a primitive root "
177  << "of unity mod " << cmod.getZExtValue()
178  << ", with the specified degree " << rootDegree.getZExtValue();
179  }
180  }
181 
182  return success();
183 }
184 
185 LogicalResult NTTOp::verify() {
186  return verifyNTTOp(this->getOperation(), getInput().getType().getRing(),
187  getOutput().getType(), getRoot());
188 }
189 
190 LogicalResult INTTOp::verify() {
191  return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(),
192  getInput().getType(), getRoot());
193 }
194 
195 ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
196  // Using the built-in parser.parseAttribute requires the full
197  // #polynomial.typed_int_polynomial syntax, which is excessive.
198  // Instead we parse a keyword int to signal it's an integer polynomial
199  Type type;
200  if (succeeded(parser.parseOptionalKeyword("float"))) {
201  Attribute floatPolyAttr = FloatPolynomialAttr::parse(parser, nullptr);
202  if (floatPolyAttr) {
203  if (parser.parseColon() || parser.parseType(type))
204  return failure();
205  result.addAttribute("value",
206  TypedFloatPolynomialAttr::get(type, floatPolyAttr));
207  result.addTypes(type);
208  return success();
209  }
210  }
211 
212  if (succeeded(parser.parseOptionalKeyword("int"))) {
213  Attribute intPolyAttr = IntPolynomialAttr::parse(parser, nullptr);
214  if (intPolyAttr) {
215  if (parser.parseColon() || parser.parseType(type))
216  return failure();
217 
218  result.addAttribute("value",
219  TypedIntPolynomialAttr::get(type, intPolyAttr));
220  result.addTypes(type);
221  return success();
222  }
223  }
224 
225  // In the worst case, still accept the verbose versions.
226  TypedIntPolynomialAttr typedIntPolyAttr;
227  OptionalParseResult res =
228  parser.parseOptionalAttribute<TypedIntPolynomialAttr>(
229  typedIntPolyAttr, "value", result.attributes);
230  if (res.has_value() && succeeded(res.value())) {
231  result.addTypes(typedIntPolyAttr.getType());
232  return success();
233  }
234 
235  TypedFloatPolynomialAttr typedFloatPolyAttr;
236  res = parser.parseAttribute<TypedFloatPolynomialAttr>(
237  typedFloatPolyAttr, "value", result.attributes);
238  if (res.has_value() && succeeded(res.value())) {
239  result.addTypes(typedFloatPolyAttr.getType());
240  return success();
241  }
242 
243  return failure();
244 }
245 
247  p << " ";
248  if (auto intPoly = dyn_cast<TypedIntPolynomialAttr>(getValue())) {
249  p << "int";
250  intPoly.getValue().print(p);
251  } else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(getValue())) {
252  p << "float";
253  floatPoly.getValue().print(p);
254  } else {
255  assert(false && "unexpected attribute type");
256  }
257  p << " : ";
258  p.printType(getOutput().getType());
259 }
260 
261 LogicalResult ConstantOp::inferReturnTypes(
262  MLIRContext *context, std::optional<mlir::Location> location,
263  ConstantOp::Adaptor adaptor,
264  llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
265  Attribute operand = adaptor.getValue();
266  if (auto intPoly = dyn_cast<TypedIntPolynomialAttr>(operand)) {
267  inferredReturnTypes.push_back(intPoly.getType());
268  } else if (auto floatPoly = dyn_cast<TypedFloatPolynomialAttr>(operand)) {
269  inferredReturnTypes.push_back(floatPoly.getType());
270  } else {
271  assert(false && "unexpected attribute type");
272  return failure();
273  }
274  return success();
275 }
276 
277 //===----------------------------------------------------------------------===//
278 // TableGen'd canonicalization patterns
279 //===----------------------------------------------------------------------===//
280 
281 namespace {
282 #include "PolynomialCanonicalization.inc"
283 } // namespace
284 
285 void SubOp::getCanonicalizationPatterns(RewritePatternSet &results,
286  MLIRContext *context) {
287  results.add<SubAsAdd>(context);
288 }
289 
290 void NTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
291  MLIRContext *context) {
292  results.add<NTTAfterINTT>(context);
293 }
294 
295 void INTTOp::getCanonicalizationPatterns(RewritePatternSet &results,
296  MLIRContext *context) {
297  results.add<INTTAfterNTT>(context);
298 }
static std::string diag(const llvm::Value &value)
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n, const APInt &cmod)
Test if a value is a primitive nth root of unity modulo cmod.
static LogicalResult verifyNTTOp(Operation *op, RingAttr ring, RankedTensorType tensorType, std::optional< PrimitiveRootAttr > root)
Verify that the types involved in an NTT or INTT operation are compatible.
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
MLIRContext * getContext() const
Definition: Builders.h:56
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
Definition: Builders.h:216
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
ParseResult value() const
Access the internal ParseResult value.
Definition: OpDefinition.h:52
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Definition: BuiltinTypes.h:102
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
NamedAttrList attributes